diff --git a/.asf.yaml b/.asf.yaml index 99fd6fac22c76..47a18d13cbca0 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -51,12 +51,20 @@ github: main: required_pull_request_reviews: required_approving_review_count: 1 + required_status_checks: + contexts: + - "Check License Header" + - "Use prettier to check formatting of documents" + - "Check Markdown Links" + - "Validate required_status_checks in .asf.yaml" + - "Spell Check with Typos" + - "Circular Dependency Check" + - "Detect Unused Dependencies" # needs to be updated as part of the release process # .asf.yaml doesn't support wildcard branch protection rules, only exact branch names # https://github.com/apache/infrastructure-asfyaml?tab=readme-ov-file#branch-protection - # Keeping set of protected branches for future releases - # Meanwhile creating a prerelease script that will update the branch protection names - # automatically. Keep track on it https://github.com/apache/datafusion/issues/17134 + # these branch protection blocks are autogenerated during the release process, as described in + # https://github.com/apache/datafusion/tree/main/dev/release#2-add-a-protection-to-release-candidate-branch branch-50: required_pull_request_reviews: required_approving_review_count: 1 @@ -66,66 +74,15 @@ github: branch-52: required_pull_request_reviews: required_approving_review_count: 1 - branch-53: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-54: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-55: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-56: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-57: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-58: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-59: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-60: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-61: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-62: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-63: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-64: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-65: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-66: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-67: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-68: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-69: - required_pull_request_reviews: - required_approving_review_count: 1 - branch-70: - required_pull_request_reviews: - required_approving_review_count: 1 pull_requests: # enable updating head branches of pull requests allow_update_branch: true allow_auto_merge: true + # auto-delete head branches after being merged + del_branch_on_merge: true # publishes the content of the `asf-site` branch to # https://datafusion.apache.org/ publish: whoami: asf-site + diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 22d2f2187dd07..6228370c955a9 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -46,3 +46,17 @@ runs: # https://github.com/actions/checkout/issues/766 shell: bash run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + - name: Remove
unnecessary preinstalled software + shell: bash + run: | + echo "Disk space before cleanup:" + df -h + apt-get clean + # remove tool cache: about 8.5GB (github has host /opt/hostedtoolcache mounted as /__t) + rm -rf /__t/* || true + # remove Haskell runtime: about 6.3GB (host /usr/local/.ghcup) + rm -rf /host/usr/local/.ghcup || true + # remove Android library: about 7.8GB (host /usr/local/lib/android) + rm -rf /host/usr/local/lib/android || true + echo "Disk space after cleanup:" + df -h \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9d1d77d44c378..2cd4bdfdd7923 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -23,6 +23,7 @@ updates: interval: weekly target-branch: main labels: [auto-dependencies] + open-pull-requests-limit: 15 ignore: # major version bumps of arrow* and parquet are handled manually - dependency-name: "arrow*" @@ -44,10 +45,27 @@ updates: patterns: - "prost*" - "pbjson*" + + # Catch-all: group only minor/patch into a single PR, + # excluding deps we always want separate (and excluding arrow/parquet which have their own group) + all-other-cargo-deps: + applies-to: version-updates + patterns: + - "*" + exclude-patterns: + - "arrow*" + - "parquet" + - "object_store" + - "sqlparser" + - "prost*" + - "pbjson*" + update-types: + - "minor" + - "patch" - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "daily" + interval: "weekly" open-pull-requests-limit: 10 labels: [auto-dependencies] - package-ecosystem: "pip" diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index f0a03d9841a9d..b7afdf3c1914d 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -33,17 +33,22 @@ on: paths: - "**/Cargo.toml" - "**/Cargo.lock" - + merge_group: +permissions: + contents: read + jobs: security_audit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-audit - uses: taiki-e/install-action@f535147c22906d77695e11cb199e764aa610a4fc # v2.62.46 + uses: taiki-e/install-action@481c34c1cf3a84c68b5e46f4eccfc82af798415a # v2.75.23 with: tool: cargo-audit - name: Run audit check + # Note: you can ignore specific RUSTSEC issues using the `--ignore` flag, for example: + # run: cargo audit --ignore RUSTSEC-2026-0001 run: cargo audit diff --git a/.github/workflows/breaking_changes_detector.yml b/.github/workflows/breaking_changes_detector.yml new file mode 100644 index 0000000000000..4a4c909e7a781 --- /dev/null +++ b/.github/workflows/breaking_changes_detector.yml @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +# Detect semver-incompatible (breaking) API changes in crates modified by a PR. +# +# Only public workspace crates that have file changes are checked. +# Internal crates (benchmarks, test-utils, sqllogictest, doc) are excluded. +# +# This workflow only runs cargo-semver-checks and uploads the result as an +# artifact. The actual PR comment is posted by a companion workflow +# (`breaking_changes_detector_comment.yml`) that picks up the artifact via +# `workflow_run`. +# +# Why split it? +# "The GITHUB_TOKEN has read-only permissions in pull requests from forked +# repositories." +# https://docs.github.com/en/actions/reference/events-that-trigger-workflows#pull_request +# A read-only token cannot post comments, so on fork PRs the previous +# single-workflow design failed with HTTP 403. We can't simply broaden the +# trigger here either: cargo-semver-checks compiles PR code (build.rs, proc +# macros), so granting this job a write token would expose it to any code +# in the PR. And ASF infra policy independently forbids `pull_request_target` +# for any workflow that exposes GITHUB_TOKEN +# (https://infra.apache.org/github-actions-policy.html). The companion +# `workflow_run` workflow runs in the base-repo context with write access +# and never executes PR code. + +name: "Detect breaking changes" + +on: + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + check-semver: + name: Check semver + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + # For fork PRs, `origin` points to the fork, not the upstream repo. + # Explicitly fetch the base branch from the upstream repo so we have + # a valid baseline ref for both diff and semver-checks. + - name: Fetch base branch + env: + BASE_REF: ${{ github.base_ref }} + REPO: ${{ github.repository }} + run: git fetch "https://github.com/${REPO}.git" "${BASE_REF}:refs/remotes/origin/${BASE_REF}" + + - name: Determine changed crates + id: changed_crates + env: + BASE_REF: ${{ github.base_ref }} + run: | + PACKAGES=$(ci/scripts/changed_crates.sh changed-crates "origin/${BASE_REF}") + echo "packages=$PACKAGES" >> "$GITHUB_OUTPUT" + echo "Changed crates: $PACKAGES" + + # `datafusion-substrait` (and crates that depend on it via sqllogictest) + # have a build script that calls protoc, which is not preinstalled on + # ubuntu-latest runners. + - name: Install Protobuf Compiler + if: steps.changed_crates.outputs.packages != '' + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler + + - name: Install cargo-semver-checks + if: steps.changed_crates.outputs.packages != '' + uses: taiki-e/install-action@94cb46f8d6e437890146ffbd78a778b78e623fb2 # v2.74.0 + with: + tool: cargo-semver-checks + + - name: Run cargo-semver-checks + id: check_semver + if: steps.changed_crates.outputs.packages != '' + env: + BASE_REF: ${{ github.base_ref }} + PACKAGES: ${{ steps.changed_crates.outputs.packages }} + run: | + set +e + # `tee` lets cargo's output stream live into the Actions log + # while we also keep a copy for the PR comment. + ci/scripts/changed_crates.sh semver-check "origin/${BASE_REF}" $PACKAGES \ + 2>&1 | tee /tmp/semver-output.txt + EXIT_CODE=${PIPESTATUS[0]} + # Pass the result through an output instead of failing the job: + # a detected breaking change should surface as a PR comment, not a + # red check, so PR authors aren't confused by an intentional break. 
+ if [ "$EXIT_CODE" -eq 0 ]; then + echo "result=success" >> "$GITHUB_OUTPUT" + else + echo "result=failure" >> "$GITHUB_OUTPUT" + fi + + # Stage the data the companion comment workflow needs into a single + # directory. We default the result to "success" so the comment + # workflow clears any stale comment when the check step is skipped + # (e.g. no published crates changed). + - name: Stage artifact for comment workflow + if: always() + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + CHECK_RESULT: ${{ steps.check_semver.outputs.result || 'success' }} + run: | + mkdir -p semver-artifact + echo "$PR_NUMBER" > semver-artifact/pr_number + echo "$CHECK_RESULT" > semver-artifact/result + if [ -f /tmp/semver-output.txt ]; then + sed 's/\x1b\[[0-9;]*m//g' /tmp/semver-output.txt > semver-artifact/logs + else + : > semver-artifact/logs + fi + + - name: Upload artifact + if: always() + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: semver-check-result + path: semver-artifact/ + retention-days: 1 diff --git a/.github/workflows/breaking_changes_detector_comment.yml b/.github/workflows/breaking_changes_detector_comment.yml new file mode 100644 index 0000000000000..8e79426082557 --- /dev/null +++ b/.github/workflows/breaking_changes_detector_comment.yml @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Companion to `breaking_changes_detector.yml`. Posts the sticky PR comment. +# +# Why this workflow exists: +# "The GITHUB_TOKEN has read-only permissions in pull requests from forked +# repositories." +# https://docs.github.com/en/actions/reference/events-that-trigger-workflows#pull_request +# That is why the upstream `pull_request` workflow cannot post the comment +# itself when the PR comes from a fork. +# +# Why not `pull_request_target`? ASF infra policy forbids it: +# "You MUST NOT use `pull_request_target` as a trigger on ANY action that +# exports ANY confidential credentials or tokens such as GITHUB_TOKEN or +# NPM_TOKEN." +# https://infra.apache.org/github-actions-policy.html +# `workflow_run` is the supported alternative: it runs in the base +# repository's context regardless of where the upstream run was triggered +# from, so the GITHUB_TOKEN here can be granted `pull-requests: write`. See: +# https://docs.github.com/en/actions/reference/events-that-trigger-workflows#workflow_run +# +# Security note: this workflow MUST NOT check out or execute any code from +# the PR. The artifact's contents originate from a workflow run that may +# have compiled fork-controlled code, so PR_NUMBER and CHECK_RESULT are +# validated against strict patterns before being passed to any action. 
+ +name: "Detect breaking changes - Comment" + +on: + workflow_run: + workflows: ["Detect breaking changes"] + types: + - completed + +permissions: + contents: read + +jobs: + comment-on-pr: + name: Comment on pull request + if: github.event.workflow_run.event == 'pull_request' + runs-on: ubuntu-latest + # Scoped to the minimum needed to upsert/delete the sticky comment. + permissions: + actions: read + pull-requests: write + steps: + - name: Download semver-check artifact + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: semver-check-result + run-id: ${{ github.event.workflow_run.id }} + github-token: ${{ github.token }} + path: ./semver-artifact + + - name: Read and validate artifact + id: read + run: | + set -euo pipefail + # Validate every field: the artifact comes from a workflow run + # that compiled fork-controlled code, so its contents are untrusted. + PR_NUMBER=$(cat ./semver-artifact/pr_number) + if ! [[ "$PR_NUMBER" =~ ^[0-9]+$ ]]; then + echo "Invalid PR number: $PR_NUMBER" >&2 + exit 1 + fi + CHECK_RESULT=$(cat ./semver-artifact/result) + if [[ "$CHECK_RESULT" != "success" && "$CHECK_RESULT" != "failure" ]]; then + echo "Invalid check result: $CHECK_RESULT" >&2 + exit 1 + fi + echo "pr_number=$PR_NUMBER" >> "$GITHUB_OUTPUT" + echo "result=$CHECK_RESULT" >> "$GITHUB_OUTPUT" + + # Multi-line output: random delimiter so a malicious log line can't + # close the heredoc and inject extra output keys. See: + # https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#multiline-strings + DELIM="EOF_$(openssl rand -hex 16)" + { + echo "logs<<${DELIM}" + cat ./semver-artifact/logs + echo "${DELIM}" + } >> "$GITHUB_OUTPUT" + + # The marker `` is what makes the comment + # "sticky": maintain-one-comment uses it to find and replace (or + # delete) the existing comment instead of stacking new ones. + - name: Upsert sticky comment + if: steps.read.outputs.result != 'success' + uses: actions-cool/maintain-one-comment@909842216bc8e8658364c572ec52100f4c2cc50a # v3.3.0 + with: + token: ${{ secrets.GITHUB_TOKEN }} + number: ${{ steps.read.outputs.pr_number }} + body-include: '' + body: | + + Thank you for opening this pull request! + + Reviewer note: [cargo-semver-checks](https://github.com/obi1kenobi/cargo-semver-checks) reported the current version number is not SemVer-compatible with the changes in this pull request (compared against the base branch). + +
+ <details> + <summary>Details</summary> + + ``` + ${{ steps.read.outputs.logs }} + ``` + + </details>
+ + - name: Delete sticky comment + if: steps.read.outputs.result == 'success' + uses: actions-cool/maintain-one-comment@909842216bc8e8658364c572ec52100f4c2cc50a # v3.3.0 + with: + token: ${{ secrets.GITHUB_TOKEN }} + number: ${{ steps.read.outputs.pr_number }} + body-include: '' + delete: true diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000000..70d38b28112de --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: '16 4 * * 1' + +permissions: + contents: read + +jobs: + analyze: + name: Analyze Actions + runs-on: ubuntu-slim + permissions: + contents: read + security-events: write + packages: read + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + with: + persist-credentials: false + + - name: Initialize CodeQL + uses: github/codeql-action/init@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4 + with: + languages: actions + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4 + with: + category: "/language:actions" diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index 7e736e1a7afbf..2f3a127ef98c4 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -25,26 +25,23 @@ on: push: branches-ignore: - 'gh-readonly-queue/**' - paths: - - "**/Cargo.toml" - - "**/Cargo.lock" pull_request: - paths: - - "**/Cargo.toml" - - "**/Cargo.lock" merge_group: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: +permissions: + contents: read + jobs: depcheck: - name: circular dependency check + name: Circular Dependency Check runs-on: ubuntu-latest container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -58,12 +55,13 @@ jobs: cargo run detect-unused-dependencies: + name: Detect Unused Dependencies runs-on: ubuntu-latest container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-machete run: cargo install cargo-machete --version ^0.9 --locked - name: Detect unused dependencies - run: cargo machete --with-metadata \ No newline at end of file + run: cargo machete --with-metadata diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 
cc879f66cc936..376e68bcd5621 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -23,6 +23,9 @@ on: pull_request: merge_group: +permissions: + contents: read + concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} cancel-in-progress: true @@ -32,28 +35,60 @@ jobs: runs-on: ubuntu-latest name: Check License Header steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install HawkEye + # This CI job is bound by installation time; use `--profile dev` to speed it up run: cargo install hawkeye --version 6.2.0 --locked --profile dev - name: Run license header check run: ci/scripts/license_header.sh prettier: name: Use prettier to check formatting of documents - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Prettier check + # if you encounter an error, see the instructions inside the script + run: ci/scripts/doc_prettier_check.sh + + markdown-link-check: + name: Check Markdown Links + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Load tool versions run: | - # if you encounter error, rerun the command below and commit the changes - # - # ignore subproject CHANGELOG.md because they are machine generated - npx prettier@2.7.1 --write \ - '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' \ - '!datafusion/CHANGELOG.md' \ - README.md \ - CONTRIBUTING.md - git diff --exit-code + source ci/scripts/utils/tool_versions.sh + echo "LYCHEE_VERSION=${LYCHEE_VERSION}" >> "$GITHUB_ENV" + - name: Install lychee + uses: taiki-e/install-action@481c34c1cf3a84c68b5e46f4eccfc82af798415a # v2.75.23 + with: + tool: lychee@${{ env.LYCHEE_VERSION }} + - name: Run markdown link check + run: bash ci/scripts/markdown_link_check.sh + + asf-yaml-check: + name: Validate required_status_checks in .asf.yaml + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - run: pip install pyyaml + - run: python3 ci/scripts/check_asf_yaml_status_checks.py + + typos: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + # Version fixed on purpose. It uses heuristics to detect typos, so upgrading + # it may cause checks to fail more often. + # We can upgrade it manually once in a while.
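The same pinned check can be reproduced locally before pushing. A short sketch, assuming `ci/scripts/typos_check.sh` (invoked in the step below) is a thin wrapper around the `typos` binary:

```bash
# Run the spell check CI uses, at the same pinned version as the step below.
cargo install typos-cli --locked --version 1.37.0
# ci/scripts/typos_check.sh is assumed to wrap a plain `typos` invocation;
# running the binary from the repo root checks the same tree.
typos
```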
+ - name: Install typos-cli + run: cargo install typos-cli --locked --version 1.37.0 + - name: Run typos check + run: ci/scripts/typos_check.sh diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 588bf46aaca70..f0fbea566af69 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -28,36 +28,37 @@ name: Deploy DataFusion site jobs: build-docs: + permissions: + contents: write name: Build docs runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Checkout asf-site branch - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: asf-site path: asf-site - - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 - with: - python-version: "3.12" + - name: Setup uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 - name: Install dependencies + run: uv sync --package datafusion-docs + - name: Install dependency graph tooling run: | set -x - python3 -m venv venv - source venv/bin/activate - pip install -r docs/requirements.txt + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs run: | set -x - source venv/bin/activate cd docs - ./build.sh + uv run --package datafusion-docs ./build.sh - name: Copy & push the generated HTML run: | diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index c182f2ef85d23..4b8d25b0611eb 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -33,31 +33,31 @@ on: # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: +permissions: + contents: read + jobs: - # Test doc build linux-test-doc-build: name: Test doc build runs-on: ubuntu-latest steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 - with: - python-version: "3.12" + - name: Setup uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 - name: Install doc dependencies + run: uv sync --package datafusion-docs + - name: Install dependency graph tooling run: | set -x - python3 -m venv venv - source venv/bin/activate - pip install -r docs/requirements.txt + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs html and check for warnings run: | set -x - source venv/bin/activate cd docs - ./build.sh # fails on errors - + uv run --package datafusion-docs ./build.sh # fails on errors diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index 2472d2e0424fd..5776aed31b761 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -44,15 +44,10 @@ on: - 'datafusion/physical*/**/*.rs' - 'datafusion/expr*/**/*.rs' - 'datafusion/optimizer/**/*.rs' + - 'datafusion/sql/**/*.rs' - 'datafusion-testing' workflow_dispatch: inputs: - pr_number: - description: 'Pull request number' - type: string - check_run_id: - description: 'Check run ID for status updates' - type: string pr_head_sha: 
description: 'PR head SHA' type: string @@ -66,10 +61,11 @@ jobs: # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=8,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -80,7 +76,9 @@ jobs: source $HOME/.cargo/env rustup toolchain install - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler - name: Prepare cargo build run: | cargo check --profile ci --all-targets @@ -90,10 +88,11 @@ jobs: linux-test-extended: name: cargo test 'extended_tests' (amd64) needs: [linux-build-lib] - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=32,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -106,7 +105,9 @@ jobs: source $HOME/.cargo/env rustup toolchain install - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler # For debugging, test binaries can be large. 
- name: Show available disk space run: | @@ -124,7 +125,7 @@ jobs: --lib \ --tests \ --bins \ - --features avro,json,backtrace,extended_tests,recursive_protection + --features avro,json,backtrace,extended_tests,recursive_protection,parquet_encryption - name: Verify Working Directory Clean run: git diff --exit-code - name: Cleanup @@ -133,11 +134,12 @@ jobs: # Check answers are correct when hash values collide hash-collisions: name: cargo test hash collisions (amd64) - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -154,24 +156,20 @@ jobs: sqllogictest-sqlite: name: "Run sqllogictests with the sqlite test suite" - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=32,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable + # Don't use setup-builder to avoid configuring RUST_BACKTRACE which is expensive + - name: Install protobuf compiler + run: | + apt-get update && apt-get install -y protobuf-compiler - name: Run sqllogictest run: | - cargo test --features backtrace --profile release-nonlto --test sqllogictests -- --include-sqlite - cargo clean - - - - + cargo test --features backtrace,parquet_encryption --profile ci-optimized --test sqllogictests -- --include-sqlite \ No newline at end of file diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0abf535b9741f..2d42d6ff964e8 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -31,7 +31,7 @@ on: jobs: process: name: Process - runs-on: ubuntu-latest + runs-on: ubuntu-slim # only run for users whose permissions allow them to update PRs # otherwise labeler is failing: # https://github.com/apache/datafusion/issues/3743 @@ -39,8 +39,6 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && diff --git a/.github/workflows/labeler/labeler-config.yml b/.github/workflows/labeler/labeler-config.yml index 38d88059dab70..0e492b6f3f6dc 100644 --- a/.github/workflows/labeler/labeler-config.yml +++ b/.github/workflows/labeler/labeler-config.yml @@ -62,7 +62,7 @@ datasource: functions: - changed-files: - - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common', 'datafusion/functions-nested', 'datafusion/functions-table/**/*', 
'datafusion/functions-window/**/*', 'datafusion/functions-window-common/**/*'] + - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common/**/*', 'datafusion/functions-nested/**/*', 'datafusion/functions-table/**/*', 'datafusion/functions-window/**/*', 'datafusion/functions-window-common/**/*'] optimizer: diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index 9cbfd6030a7f6..5a127e443fcb7 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -25,18 +25,21 @@ on: pull_request: merge_group: +permissions: + contents: read + jobs: check-files: - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - name: Check size of new Git objects env: - # 1 MB ought to be enough for anybody. + # 1.5 MB ought to be enough for anybody. # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label - MAX_FILE_SIZE_BYTES: 1048576 + MAX_FILE_SIZE_BYTES: 1572864 shell: bash run: | if [ "${{ github.event_name }}" = "merge_group" ]; then diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 88d9f4e13378c..3f6462f0f01c1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +# For some actions, we use Runs-On to run them on ASF infrastructure: https://datafusion.apache.org/contributor-guide/#ci-runners + name: Rust concurrency: @@ -36,26 +38,29 @@ on: - "**.md" - ".github/ISSUE_TEMPLATE/**" - ".github/pull_request_template.md" - merge_group: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: +permissions: + contents: read + jobs: # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=8,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: shared-key: "amd-ci-check" # this job uses it's own cache becase check has a separate cache and we need it to be fast as it blocks other jobs save-if: ${{ github.ref_name == 'main' }} @@ -77,7 +82,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -98,17 +103,17 @@ jobs: linux-datafusion-substrait-features: name: cargo check datafusion-substrait features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ 
github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: save-if: false # set in linux-test shared-key: "amd-ci" @@ -135,11 +140,12 @@ jobs: linux-datafusion-proto-features: name: cargo check datafusion-proto features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -166,17 +172,18 @@ jobs: linux-cargo-check-datafusion: name: cargo check datafusion features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: save-if: false # set in linux-test shared-key: "amd-ci" @@ -209,8 +216,6 @@ jobs: run: cargo check --profile ci --no-default-features -p datafusion --features=math_expressions - name: Check datafusion (parquet) run: cargo check --profile ci --no-default-features -p datafusion --features=parquet - - name: Check datafusion (pyarrow) - run: cargo check --profile ci --no-default-features -p datafusion --features=pyarrow - name: Check datafusion (regex_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=regex_expressions - name: Check datafusion (recursive_protection) @@ -237,7 +242,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -268,11 +273,14 @@ jobs: linux-test: name: cargo test (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: 
amd64/rust + volumes: + - /usr/local:/host/usr/local steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -281,7 +289,7 @@ jobs: with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: save-if: ${{ github.ref_name == 'main' }} shared-key: "amd-ci" @@ -299,7 +307,7 @@ jobs: --lib \ --tests \ --bins \ - --features serde,avro,json,backtrace,integration-tests,parquet_encryption + --features serde,avro,json,backtrace,integration-tests,parquet_encryption,substrait - name: Verify Working Directory Clean run: git diff --exit-code # Check no temporary directories created during test. @@ -316,16 +324,17 @@ jobs: linux-test-datafusion-cli: name: cargo test datafusion-cli (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - name: Setup Rust toolchain run: rustup toolchain install stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: save-if: false # set in linux-test shared-key: "amd-ci" @@ -345,11 +354,12 @@ jobs: linux-test-example: name: cargo examples (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -358,23 +368,10 @@ jobs: with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: save-if: ${{ github.ref_name == 'main' }} shared-key: "amd-ci-linux-test-example" - - name: Remove unnecessary preinstalled software - run: | - echo "Disk space before cleanup:" - df -h - apt-get clean - rm -rf /__t/CodeQL - rm -rf /__t/PyPy - rm -rf /__t/Java_Temurin-Hotspot_jdk - rm -rf /__t/Python - rm -rf /__t/go - rm -rf /__t/Ruby - echo "Disk space after cleanup:" - df -h - name: Run examples run: | # test datafusion-sql examples @@ -388,11 +385,12 @@ jobs: linux-test-doc: name: cargo test doc (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} 
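A recurring change in this file (and in extended.yml above): `runs-on` becomes `${{ cond && format(...) || 'ubuntu-latest' }}`. GitHub expressions have no ternary operator, but because `format(...)` always produces a truthy string, the `&&`/`||` pair selects between the Runs-On self-hosted pool (for the `apache` org) and stock runners (for forks). The same short-circuit idiom in shell, for intuition:

```bash
# Shell analogue of the GitHub expression `cond && format(...) || 'ubuntu-latest'`.
# When the middle operand is always truthy, && / || act as a ternary.
owner="apache"   # CI reads github.repository_owner; the value here is illustrative
runner=$([ "$owner" = "apache" ] && echo "runs-on-self-hosted-pool" || echo "ubuntu-latest")
echo "$runner"   # prints: runs-on-self-hosted-pool
```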
container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -409,11 +407,12 @@ jobs: linux-rustdoc: name: cargo doc needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -425,7 +424,7 @@ jobs: name: build and run with wasm-pack runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup for wasm32 run: | rustup target add wasm32-unknown-unknown @@ -434,7 +433,7 @@ jobs: sudo apt-get update -qq sudo apt-get install -y -qq clang - name: Setup wasm-pack - uses: taiki-e/install-action@f535147c22906d77695e11cb199e764aa610a4fc # v2.62.46 + uses: taiki-e/install-action@481c34c1cf3a84c68b5e46f4eccfc82af798415a # v2.75.23 with: tool: wasm-pack - name: Run tests with headless mode @@ -449,11 +448,12 @@ jobs: verify-benchmark-results: name: verify benchmark results (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -475,14 +475,14 @@ jobs: export RUST_MIN_STACK=20971520 export TPCH_DATA=`realpath datafusion/sqllogictest/test_files/tpch/data` cargo test plan_q --package datafusion-benchmarks --profile ci --features=ci -- --test-threads=1 - INCLUDE_TPCH=true cargo test --features backtrace --profile ci --package datafusion-sqllogictest --test sqllogictests + INCLUDE_TPCH=true cargo test --features backtrace,parquet_encryption,substrait --profile ci --package datafusion-sqllogictest --test sqllogictests - name: Verify Working Directory Clean run: git diff --exit-code sqllogictest-postgres: name: "Run sqllogictest with Postgres runner" needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust services: @@ -500,7 +500,8 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -520,11 +521,12 @@ jobs: 
sqllogictest-substrait: name: "Run sqllogictest in Substrait round-trip mode" needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -537,7 +539,7 @@ jobs: # command cannot be run for all the .slt files. Run it for just one that works (limit.slt) # until most of the tickets in https://github.com/apache/datafusion/issues/16248 are addressed # and this command can be run without filters. - run: cargo test --test sqllogictests -- --substrait-round-trip limit.slt + run: cargo test -p datafusion-sqllogictest --test sqllogictests --features substrait -- --substrait-round-trip limit.slt # Temporarily commenting out the Windows flow, the reason is enormously slow running build # Waiting for new Windows 2025 github runner @@ -560,9 +562,9 @@ jobs: macos-aarch64: name: cargo test (macos-aarch64) - runs-on: macos-14 + runs-on: macos-15 steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -570,31 +572,7 @@ jobs: uses: ./.github/actions/setup-macos-aarch64-builder - name: Run tests (excluding doctests) shell: bash - run: cargo test --profile ci --exclude datafusion-cli --workspace --lib --tests --bins --features avro,json,backtrace,integration-tests - - test-datafusion-pyarrow: - name: cargo test pyarrow (amd64) - needs: linux-build-lib - runs-on: ubuntu-latest - container: - image: amd64/rust:bullseye # Use the bullseye tag image which comes with python3.9 - steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - submodules: true - fetch-depth: 1 - - name: Install PyArrow - run: | - echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV - apt-get update - apt-get install python3-pip -y - python3 -m pip install pyarrow - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run datafusion-common tests - run: cargo test --profile ci -p datafusion-common --features=pyarrow,sql + run: cargo test --profile ci --exclude datafusion-cli --workspace --lib --tests --bins --features avro,json,backtrace,integration-tests,substrait vendor: name: Verify Vendored Code @@ -602,7 +580,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -619,7 +597,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -674,11 +652,12 @@ jobs: clippy: name: clippy needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && 
format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -689,7 +668,7 @@ jobs: - name: Install Clippy run: rustup component add clippy - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 with: save-if: ${{ github.ref_name == 'main' }} shared-key: "amd-ci-clippy" @@ -703,7 +682,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -720,11 +699,12 @@ jobs: config-docs-check: name: check configs.md and ***_functions.md is up-to-date needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: runs-on/action@742bf56072eb4845a0f94b3394673e4903c90ff0 # v2.1.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -732,7 +712,7 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Check if configs.md has been modified @@ -746,6 +726,38 @@ jobs: ./dev/update_function_docs.sh git diff --exit-code +# This job ensures `datafusion-examples/README.md` stays in sync with the source code: +# 1. Generates README automatically using the Rust examples docs generator +# (parsing documentation from `examples//main.rs`) +# 2. Formats the generated Markdown using DataFusion's standard Prettier setup +# 3. 
Compares the result against the committed README.md and fails if out-of-date + examples-docs-check: + name: check example README is up-to-date + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + submodules: true + fetch-depth: 1 + + - name: Mark repository as safe for git + # Required for git commands inside container (avoids "dubious ownership" error) + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Set up Node.js (required for prettier) + # doc_prettier_check.sh uses npx to run prettier for Markdown formatting + uses: actions/setup-node@v6 + with: + node-version: '18' + + - name: Run examples docs check script + run: | + bash ci/scripts/check_examples_docs.sh + # Verify MSRV for the crates which are directly used by other projects: # - datafusion # - datafusion-substrait @@ -757,11 +769,11 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - uses: taiki-e/install-action@f535147c22906d77695e11cb199e764aa610a4fc # v2.62.46 + uses: taiki-e/install-action@481c34c1cf3a84c68b5e46f4eccfc82af798415a # v2.75.23 with: tool: cargo-msrv @@ -799,11 +811,3 @@ jobs: - name: Check datafusion-proto working-directory: datafusion/proto run: cargo msrv --output-format json --log-target stdout verify - typos: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - persist-credentials: false - - uses: crate-ci/typos@07d900b8fa1097806b8adb6391b0d3e0ac2fdea7 # v1.39.0 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index d5fc9287aa6a5..8627b3bf044ff 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -22,12 +22,13 @@ on: jobs: close-stale-prs: - runs-on: ubuntu-latest + runs-on: ubuntu-slim permissions: + actions: write issues: write pull-requests: write steps: - - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 with: stale-pr-message: "Thank you for your contribution. Unfortunately, this pull request is stale because it has been open 60 days with no activity. Please remove the stale label or comment or this will be closed in 7 days." days-before-pr-stale: 60 @@ -36,3 +37,4 @@ jobs: days-before-issue-stale: -1 days-before-issue-close: -1 repo-token: ${{ secrets.GITHUB_TOKEN }} + operations-per-run: 150 diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml index 86dc190add1d1..e34bf869ef8a0 100644 --- a/.github/workflows/take.yml +++ b/.github/workflows/take.yml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
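The take.yml rework below routes every `${{ ... }}` value (comment body, user login, issue number, repo) through `env:` instead of splicing it into the script. The diff does not state the motivation, but this is the standard guard against expression injection, since `${{ ... }}` is expanded into the script text before bash ever parses it. A small illustration:

```bash
# What direct interpolation would do: a workflow line like
#   if [ "${{ github.event.comment.body }}" == "take" ]; then ...
# is rewritten by the Actions runner BEFORE bash runs, so hostile comment
# text becomes part of the script source. With env: indirection the same
# text arrives as plain data in a variable and cannot escape its quotes:
COMMENT_BODY='take" ]; then :; fi; echo pwned; if [ "x'   # hostile example
if [ "$COMMENT_BODY" == "take" ]; then
  echo "assign"
else
  echo "not a take comment; the payload stayed inert"
fi
```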
-name: Assign the issue via a `take` comment +name: Assign/unassign the issue via `take` or `untake` comment on: issue_comment: types: created @@ -25,17 +25,31 @@ permissions: jobs: issue_assign: - runs-on: ubuntu-latest - if: (!github.event.issue.pull_request) && github.event.comment.body == 'take' + runs-on: ubuntu-slim + if: (!github.event.issue.pull_request) && (github.event.comment.body == 'take' || github.event.comment.body == 'untake') concurrency: group: ${{ github.actor }}-issue-assign steps: - - run: | - CODE=$(curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" -LI https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees/${{ github.event.comment.user.login }} -o /dev/null -w '%{http_code}\n' -s) - if [ "$CODE" -eq "204" ] + - name: Take or untake issue + env: + COMMENT_BODY: ${{ github.event.comment.body }} + ISSUE_NUMBER: ${{ github.event.issue.number }} + USER_LOGIN: ${{ github.event.comment.user.login }} + REPO: ${{ github.repository }} + TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + if [ "$COMMENT_BODY" == "take" ] then - echo "Assigning issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}" - curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" -d '{"assignees": ["${{ github.event.comment.user.login }}"]}' https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees - else - echo "Cannot assign issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}" + CODE=$(curl -H "Authorization: token $TOKEN" -LI https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees/$USER_LOGIN -o /dev/null -w '%{http_code}\n' -s) + if [ "$CODE" -eq "204" ] + then + echo "Assigning issue $ISSUE_NUMBER to $USER_LOGIN" + curl -X POST -H "Authorization: token $TOKEN" -H "Content-Type: application/json" -d "{\"assignees\": [\"$USER_LOGIN\"]}" https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees + else + echo "Cannot assign issue $ISSUE_NUMBER to $USER_LOGIN" + fi + elif [ "$COMMENT_BODY" == "untake" ] + then + echo "Unassigning issue $ISSUE_NUMBER from $USER_LOGIN" + curl -X DELETE -H "Authorization: token $TOKEN" -H "Content-Type: application/json" -d "{\"assignees\": [\"$USER_LOGIN\"]}" https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees fi \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8466a72adaec8..c1f9677e47366 100644 --- a/.gitignore +++ b/.gitignore @@ -75,3 +75,9 @@ rat.txt # data generated by examples datafusion-examples/examples/datafusion-examples/ + +# Samply profile data +profile.json.gz + +# Claude Code personal settings +.claude/settings.local.json diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000000..9dff7f6f1ffd1 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,41 @@ +# Agent Guidelines for Apache DataFusion + +## Developer Documentation + +- [Quick Start Setup](docs/source/contributor-guide/development_environment.md#quick-start) +- [Testing Quick Start](docs/source/contributor-guide/testing.md#testing-quick-start) +- [Before Submitting a PR](docs/source/contributor-guide/index.md#before-submitting-a-pr) +- [Contributor Guide](docs/source/contributor-guide/index.md) +- [Architecture Guide](docs/source/contributor-guide/architecture.md) + +## Before Committing + +Before committing any changes, you MUST follow the instructions in +[Before Submitting a PR](docs/source/contributor-guide/index.md#before-submitting-a-pr) +and ensure the required checks listed there 
pass. Do not commit code that +fails any of those checks. + +At a minimum, you MUST run these commands and fix any errors they report +before committing: + +```bash +# Format code +cargo fmt --all + +# Lint (must pass with no warnings) +cargo clippy --all-targets --all-features -- -D warnings +``` + +You can also run the full lint suite used by CI: + +```bash +./dev/rust_lint.sh +# or auto-fix: ./dev/rust_lint.sh --write --allow-dirty +``` + +When creating a PR, you MUST follow the [PR template](.github/pull_request_template.md). + +## Testing + +See the [Testing Quick Start](docs/source/contributor-guide/testing.md#testing-quick-start) +for the recommended pre-PR test commands. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 0000000000000..47dc3e3d863cf --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f500265108ff5..af52588e5338e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,71 +2,12 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "abi_stable" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d6512d3eb05ffe5004c59c206de7f99c34951504056ce23fc953842f12c445" -dependencies = [ - "abi_stable_derive", - "abi_stable_shared", - "const_panic", - "core_extensions", - "crossbeam-channel", - "generational-arena", - "libloading 0.7.4", - "lock_api", - "parking_lot", - "paste", - "repr_offset", - "rustc_version", - "serde", - "serde_derive", - "serde_json", -] - -[[package]] -name = "abi_stable_derive" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7178468b407a4ee10e881bc7a328a65e739f0863615cca4429d43916b05e898" -dependencies = [ - "abi_stable_shared", - "as_derive_utils", - "core_extensions", - "proc-macro2", - "quote", - "rustc_version", - "syn 1.0.109", - "typed-arena", -] - -[[package]] -name = "abi_stable_shared" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2b5df7688c123e63f4d4d649cba63f2967ba7f7861b1664fca3f77d3dad2b63" -dependencies = [ - "core_extensions", -] - [[package]] name = "adler2" version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" -[[package]] -name = "ahash" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" -dependencies = [ - "getrandom 0.2.16", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.8.12" @@ -83,9 +24,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] @@ -105,6 +46,15 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -128,9 +78,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.20" +version = "1.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", "anstyle-parse", @@ -143,33 +93,33 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" -version = "0.2.7" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ "windows-sys 0.60.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.10" +version = "3.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", @@ -178,37 +128,17 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.100" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] -name = "apache-avro" -version = "0.20.0" +name = "ar_archive_writer" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a033b4ced7c585199fb78ef50fca7fe2f444369ec48080c5fd072efa1a03cc7" +checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" dependencies = [ - "bigdecimal", - "bon", - "bzip2 0.6.1", - "crc32fast", - "digest", - "log", - "miniz_oxide", - "num-bigint", - "quad-rand", - "rand 0.9.2", - "regex-lite", - "serde", - "serde_bytes", - "serde_json", - "snap", - "strum 0.27.2", - "strum_macros 0.27.2", - "thiserror", - "uuid", - "xz2", - "zstd", + "object", ] [[package]] @@ -225,9 +155,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df8bb5b0bd64c0b9bc61317fcc480bad0f00e56d3bc32c69a4c8dada4786bae" +checksum = "d441fdda254b65f3e9025910eb2c2066b6295d9c8ed409522b8d2ace1ff8574c" dependencies = [ "arrow-arith", "arrow-array", @@ -238,20 +168,19 @@ dependencies = [ "arrow-ipc", "arrow-json", "arrow-ord", - "arrow-pyarrow", "arrow-row", "arrow-schema", "arrow-select", "arrow-string", "half", - "rand 0.9.2", + "rand 0.9.4", ] [[package]] name = "arrow-arith" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1a640186d3bd30a24cb42264c2dafb30e236a6f50d510e56d40b708c9582491" +checksum = "ced5406f8b720cc0bc3aa9cf5758f93e8593cda5490677aa194e4b4b383f9a59" dependencies = [ 
"arrow-array", "arrow-buffer", @@ -263,28 +192,52 @@ dependencies = [ [[package]] name = "arrow-array" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219fe420e6800979744c8393b687afb0252b3f8a89b91027d27887b72aa36d31" +checksum = "772bd34cacdda8baec9418d80d23d0fb4d50ef0735685bd45158b83dfeb6e62d" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", "chrono-tz", "half", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "num-complex", "num-integer", "num-traits", ] +[[package]] +name = "arrow-avro" +version = "58.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36a3aadd016f63dfd4941ae8e13539ba98a3c2995adc3c88b9336d2514f6c8a7" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", + "bytes", + "bzip2", + "crc", + "flate2", + "indexmap 2.14.0", + "liblzma", + "rand 0.9.4", + "serde", + "serde_json", + "snap", + "strum_macros", + "uuid", + "zstd", +] + [[package]] name = "arrow-buffer" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76885a2697a7edf6b59577f568b456afc94ce0e2edc15b784ce3685b6c3c5c27" +checksum = "898f4cf1e9598fdb77f356fdf2134feedfd0ee8d5a4e0a5f573e7d0aec16baa4" dependencies = [ "bytes", "half", @@ -294,13 +247,14 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9ebb4c987e6b3b236fb4a14b20b34835abfdd80acead3ccf1f9bf399e1f168" +checksum = "b0127816c96533d20fc938729f48c52d3e48f99717e7a0b5ade77d742510736d" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", + "arrow-ord", "arrow-schema", "arrow-select", "atoi", @@ -315,9 +269,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92386159c8d4bce96f8bd396b0642a0d544d471bdc2ef34d631aec80db40a09c" +checksum = "ca025bd0f38eeecb57c2153c0123b960494138e6a957bbda10da2b25415209fe" dependencies = [ "arrow-array", "arrow-cast", @@ -330,9 +284,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727681b95de313b600eddc2a37e736dcb21980a40f640314dcf360e2f36bc89b" +checksum = "42d10beeab2b1c3bb0b53a00f7c944a178b622173a5c7bcabc3cb45d90238df4" dependencies = [ "arrow-buffer", "arrow-schema", @@ -343,9 +297,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f70bb56412a007b0cfc116d15f24dda6adeed9611a213852a004cda20085a3b9" +checksum = "302b2e036335f3f04d65dad3f74ff1f2aae6dc671d6aa04dc6b61193761e16fb" dependencies = [ "arrow-arith", "arrow-array", @@ -371,9 +325,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9ba92e3de170295c98a84e5af22e2b037f0c7b32449445e6c493b5fca27f27" +checksum = "609a441080e338147a84e8e6904b6da482cefb957c5cdc0f3398872f69a315d0" dependencies = [ "arrow-array", "arrow-buffer", @@ -387,9 +341,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b969b4a421ae83828591c6bf5450bd52e6d489584142845ad6a861f42fe35df8" +checksum = "6ead0914e4861a531be48fe05858265cf854a4880b9ed12618b1d08cba9bebc8" dependencies = [ "arrow-array", "arrow-buffer", @@ -398,7 +352,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.12.0", + "indexmap 2.14.0", "itoa", "lexical-core", "memchr", @@ -411,9 +365,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "141c05298b21d03e88062317a1f1a73f5ba7b6eb041b350015b1cd6aabc0519b" +checksum = "763a7ba279b20b52dad300e68cfc37c17efa65e68623169076855b3a9e941ca5" dependencies = [ "arrow-array", "arrow-buffer", @@ -422,23 +376,11 @@ dependencies = [ "arrow-select", ] -[[package]] -name = "arrow-pyarrow" -version = "57.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcfb2be2e9096236f449c11f425cddde18c4cc540f516d90f066f10a29ed515" -dependencies = [ - "arrow-array", - "arrow-data", - "arrow-schema", - "pyo3", -] - [[package]] name = "arrow-row" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f3c06a6abad6164508ed283c7a02151515cef3de4b4ff2cebbcaeb85533db2" +checksum = "e14fe367802f16d7668163ff647830258e6e0aeea9a4d79aaedf273af3bdcd3e" dependencies = [ "arrow-array", "arrow-buffer", @@ -449,11 +391,11 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cfa7a03d1eee2a4d061476e1840ad5c9867a544ca6c4c59256496af5d0a8be5" +checksum = "c30a1365d7a7dc50cc847e54154e6af49e4c4b0fddc9f607b687f29212082743" dependencies = [ - "bitflags 2.9.4", + "bitflags", "serde", "serde_core", "serde_json", @@ -461,11 +403,11 @@ dependencies = [ [[package]] name = "arrow-select" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bafa595babaad59f2455f4957d0f26448fb472722c186739f4fac0823a1bdb47" +checksum = "78694888660a9e8ac949853db393af2a8b8fc82c19ce333132dfa2e72cc1a7fe" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -475,9 +417,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f46457dbbb99f2650ff3ac23e46a929e0ab81db809b02aa5511c258348bef2" +checksum = "61e04a01f8bb73ce54437514c5fd3ee2aa3e8abe4c777ee5cc55853b1652f79e" dependencies = [ "arrow-array", "arrow-buffer", @@ -490,23 +432,11 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "as_derive_utils" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff3c96645900a44cf11941c111bd08a6573b0e2f9f69bc9264b179d8fae753c4" -dependencies = [ - "core_extensions", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "astral-tokio-tar" -version = "0.5.6" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec179a06c1769b1e42e1e2cbe74c7dcdb3d6383c838454d063eaac5bbb7ebbe5" +checksum = "4ce73b17c62717c4b6a9af10b43e87c578b0cac27e00666d48304d3b7d2c0693" dependencies = [ "filetime", "futures-core", @@ -520,19 +450,14 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.19" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"06575e6a9673580f52661c92107baabffbf41e2141373441cbcdc47cb733003c" +checksum = "e79b3f8a79cccc2898f31920fc69f304859b3bd567490f75ebf51ae1c792a9ac" dependencies = [ - "bzip2 0.5.2", - "flate2", - "futures-core", - "memchr", + "compression-codecs", + "compression-core", "pin-project-lite", "tokio", - "xz2", - "zstd", - "zstd-safe", ] [[package]] @@ -540,9 +465,6 @@ name = "async-ffi" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4de21c0feef7e5a556e51af767c953f0501f7f300ba785cc99c47bdc8081a50" -dependencies = [ - "abi_stable", -] [[package]] name = "async-recursion" @@ -552,7 +474,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -574,7 +496,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -585,7 +507,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -611,9 +533,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.8.7" +version = "1.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04b37ddf8d2e9744a0b9c19ce0b78efe4795339a90b66b7bae77987092cd2e69" +checksum = "50f156acdd2cf55f5aa53ee416c4ac851cf1222694506c0b1f78c85695e9ca9d" dependencies = [ "aws-credential-types", "aws-runtime", @@ -630,8 +552,8 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.3.1", - "ring", + "http 1.4.0", + "sha1 0.10.6", "time", "tokio", "tracing", @@ -641,9 +563,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.7" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799a1290207254984cb7c05245111bc77958b92a3c9bb449598044b36341cce6" +checksum = "8f20799b373a1be121fe3005fba0c2090af9411573878f224df44b42727fcaf7" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -653,9 +575,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.14.0" +version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b8ff6c09cd57b16da53641caa860168b88c172a5ee163b0288d3d6eea12786" +checksum = "0ec6fb3fe69024a75fa7e1bfb48aa6cf59706a101658ea01bfd33b2b248a038f" dependencies = [ "aws-lc-sys", "zeroize", @@ -663,11 +585,10 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.31.0" +version = "0.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e44d16778acaf6a9ec9899b92cebd65580b83f685446bf2e1f5d3d732f99dcd" +checksum = "f50037ee5e1e41e7b8f9d161680a725bd1626cb6f8c7e901f91f942850852fe7" dependencies = [ - "bindgen", "cc", "cmake", "dunce", @@ -676,9 +597,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.11" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e1ed337dabcf765ad5f2fb426f13af22d576328aaf09eac8f70953530798ec0" +checksum = "5dcd93c82209ac7413532388067dce79be5a8780c1786e5fae3df22e4dee2864" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -689,9 +610,10 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", + "bytes-utils", "fastrand", - "http 0.2.12", - "http-body 0.4.6", + "http 1.4.0", + "http-body 1.0.1", "percent-encoding", "pin-project-lite", 
"tracing", @@ -700,15 +622,16 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.85.0" +version = "1.98.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f2c741e2e439f07b5d1b33155e246742353d82167c785a2ff547275b7e32483" +checksum = "d69c77aafa20460c68b6b3213c84f6423b6e76dbf89accd3e1789a686ffd9489" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -716,21 +639,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.87.0" +version = "1.100.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6428ae5686b18c0ee99f6f3c39d94ae3f8b42894cdc35c35d8fb2470e9db2d4c" +checksum = "1c7e7b09346d5ca22a2a08267555843a6a0127fb20d8964cb6ecfb8fdb190225" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -738,21 +663,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.87.0" +version = "1.103.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5871bec9a79a3e8d928c7788d654f135dde0e71d2dd98089388bab36b37ef607" +checksum = "c2249b81a2e73a8027c41c378463a81ec39b8510f184f2caab87de912af0f49b" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -761,15 +688,16 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.4" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084c34162187d39e3740cb635acd73c4e3a551a36146ad6fe8883c929c9f876c" +checksum = "68dc0b907359b120170613b5c09ccc61304eac3998ff6274b97d93ee6490115a" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -780,7 +708,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "percent-encoding", "sha2", "time", @@ -789,9 +717,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.5" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" +checksum = "2ffcaf626bdda484571968400c326a244598634dc75fd451325a54ad1a59acfc" dependencies = [ "futures-util", "pin-project-lite", @@ -800,18 +728,19 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.3" +version = "0.63.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c4dacf2d38996cf729f55e7a762b30918229917eca115de45dfa8dfb97796c9" +checksum = "ba1ab2dc1c2c3749ead27180d333c42f11be8b0e934058fb4b2258ee8dbe5231" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", - "http 0.2.12", - "http 1.3.1", - "http-body 0.4.6", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", "percent-encoding", "pin-project-lite", "pin-utils", @@ -820,15 +749,15 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.1.1" +version = 
"1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147e8eea63a40315d704b97bf9bc9b8c1402ae94f89d5ad6f7550d963309da1b" +checksum = "6a2f165a7feee6f263028b899d0a181987f4fa7179a6411a32a439fba7c5f769" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.3.1", + "http 1.4.0", "hyper", "hyper-rustls", "hyper-util", @@ -844,27 +773,27 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.61.5" +version = "0.62.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaa31b350998e703e9826b2104dd6f63be0508666e1aba88137af060e8944047" +checksum = "9648b0bb82a2eedd844052c6ad2a1a822d1f8e3adee5fbf668366717e428856a" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.3" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +checksum = "a06c2315d173edbf1920da8ba3a7189695827002e4c0fc961973ab1c54abca9c" dependencies = [ "aws-smithy-runtime-api", ] [[package]] name = "aws-smithy-query" -version = "0.60.7" +version = "0.60.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +checksum = "1a56d79744fb3edb5d722ef79d86081e121d3b9422cb209eb03aea6aa4f21ebd" dependencies = [ "aws-smithy-types", "urlencoding", @@ -872,9 +801,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.9.2" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fa63ad37685ceb7762fa4d73d06f1d5493feb88e3f27259b9ed277f4c01b185" +checksum = "0504b1ab12debb5959e5165ee5fe97dd387e7aa7ea6a477bfd7635dfe769a4f5" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -885,9 +814,10 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", + "http-body-util", "pin-project-lite", "pin-utils", "tokio", @@ -896,32 +826,44 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.9.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07f5e0fc8a6b3f2303f331b94504bbf754d85488f402d6f1dd7a6080f99afe56" +checksum = "b71a13df6ada0aafbf21a73bdfcdf9324cfa9df77d96b8446045be3cde61b42e" dependencies = [ "aws-smithy-async", + "aws-smithy-runtime-api-macros", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "pin-project-lite", "tokio", "tracing", "zeroize", ] +[[package]] +name = "aws-smithy-runtime-api-macros" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d7396fd9500589e62e460e987ecb671bad374934e55ec3b5f498cc7a8a8a7b7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "aws-smithy-types" -version = "1.3.2" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" +checksum = "9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" dependencies = [ "base64-simd", "bytes", "bytes-utils", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -936,18 +878,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.10" +version = "0.60.15" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "3db87b96cb1b16c024980f133968d52882ca0daaee3a086c6decc500f6c99728" +checksum = "0ce02add1aa3677d022f8adf81dcbe3046a95f17a1b1e8979c145cd21d3d22b3" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.8" +version = "1.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b069d19bf01e46298eaedd7c6f283fe565a59263e53eebec945f3e6398f42390" +checksum = "2f4bbcaa9304ea40902d3d5f42a0428d1bd895a2b0f6999436fb279ffddc58ac" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -959,14 +901,14 @@ dependencies = [ [[package]] name = "axum" -version = "0.8.4" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "itoa", @@ -975,8 +917,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rustversion", - "serde", + "serde_core", "sync_wrapper", "tower", "tower-layer", @@ -985,18 +926,17 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.2" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", - "rustversion", "sync_wrapper", "tower-layer", "tower-service", @@ -1026,61 +966,22 @@ dependencies = [ [[package]] name = "bigdecimal" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a22f228ab7a1b23027ccc6c350b72868017af7ea8356fbdf19f8d991c690013" +checksum = "4d6867f1565b3aad85681f1015055b087fcfd840d6aeee6eee7f2da317603695" dependencies = [ "autocfg", "libm", "num-bigint", "num-integer", "num-traits", - "serde", -] - -[[package]] -name = "bindgen" -version = "0.72.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" -dependencies = [ - "bitflags 2.9.4", - "cexpr", - "clang-sys", - "itertools 0.13.0", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", - "syn 2.0.108", ] [[package]] name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" - -[[package]] -name = "bitvec" -version = "1.0.1" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" [[package]] name = "blake2" @@ -1088,20 +989,21 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest", + "digest 0.10.7", ] [[package]] 
name = "blake3" -version = "1.8.2" +version = "1.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq", + "cpufeatures 0.3.0", ] [[package]] @@ -1113,24 +1015,32 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" +dependencies = [ + "hybrid-array", +] + [[package]] name = "bollard" -version = "0.19.3" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec7646ee90964aa59e9f832a67182791396a19a5b1d76eb17599a8310a7e2e09" +checksum = "ee04c4c84f1f811b017f2fbb7dd8815c976e7ca98593de9c1e2afad0f636bff4" dependencies = [ "async-stream", "base64 0.22.1", - "bitflags 2.9.4", + "bitflags", "bollard-buildkit-proto", "bollard-stubs", "bytes", - "chrono", "futures-core", "futures-util", "hex", "home", - "http 1.3.1", + "http 1.4.0", "http-body-util", "hyper", "hyper-named-pipe", @@ -1140,17 +1050,16 @@ dependencies = [ "log", "num", "pin-project-lite", - "rand 0.9.2", + "rand 0.9.4", "rustls", "rustls-native-certs", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_derive", "serde_json", - "serde_repr", "serde_urlencoded", "thiserror", + "time", "tokio", "tokio-stream", "tokio-util", @@ -1175,67 +1084,18 @@ dependencies = [ [[package]] name = "bollard-stubs" -version = "1.49.1-rc.28.4.0" +version = "1.52.1-rc.29.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5731fe885755e92beff1950774068e0cae67ea6ec7587381536fca84f1779623" +checksum = "0f0a8ca8799131c1837d1282c3f81f31e76ceb0ce426e04a7fe1ccee3287c066" dependencies = [ "base64 0.22.1", "bollard-buildkit-proto", "bytes", - "chrono", "prost", "serde", "serde_json", "serde_repr", - "serde_with", -] - -[[package]] -name = "bon" -version = "3.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2529c31017402be841eb45892278a6c21a000c0a17643af326c73a73f83f0fb" -dependencies = [ - "bon-macros", - "rustversion", -] - -[[package]] -name = "bon-macros" -version = "3.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82020dadcb845a345591863adb65d74fa8dc5c18a0b6d408470e13b7adc7005" -dependencies = [ - "darling", - "ident_case", - "prettyplease", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.108", -] - -[[package]] -name = "borsh" -version = "1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8646f98db542e39fc66e68a20b2144f6a732636df7c2354e74645faaa433ce" -dependencies = [ - "borsh-derive", - "cfg_aliases", -] - -[[package]] -name = "borsh-derive" -version = "1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd1d3c0c2f5833f22386f252fe8ed005c7f59fdcddeef025c01b4c3b9fd9ac3" -dependencies = [ - "once_cell", - "proc-macro-crate", - "proc-macro2", - "quote", - "syn 2.0.108", + "time", ] [[package]] @@ -1261,9 +1121,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" +checksum = 
"63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" dependencies = [ "memchr", "serde", @@ -1271,31 +1131,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" - -[[package]] -name = "bytecheck" -version = "0.6.12" +version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23cdc57ce23ac53c931e88a43d06d070a6fd142f2617be5855eb75efc9beb1c2" -dependencies = [ - "bytecheck_derive", - "ptr_meta", - "simdutf8", -] - -[[package]] -name = "bytecheck_derive" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3db406d29fbcd95542e92559bed4d8ad92636d1ca8b3b72ede10b4bcc010e659" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" [[package]] name = "byteorder" @@ -1305,9 +1143,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "bytes-utils" @@ -1319,15 +1157,6 @@ dependencies = [ "either", ] -[[package]] -name = "bzip2" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" -dependencies = [ - "bzip2-sys", -] - [[package]] name = "bzip2" version = "0.6.1" @@ -1337,16 +1166,6 @@ dependencies = [ "libbz2-rs-sys", ] -[[package]] -name = "bzip2-sys" -version = "0.1.13+1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" -dependencies = [ - "cc", - "pkg-config", -] - [[package]] name = "cast" version = "0.3.0" @@ -1355,9 +1174,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.38" +version = "1.2.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80f41ae168f955c12fb8960b057d70d0ca153fb83182b57d86380443527be7e9" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" dependencies = [ "find-msvc-tools", "jobserver", @@ -1365,20 +1184,11 @@ dependencies = [ "shlex", ] -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cfg_aliases" @@ -1387,17 +1197,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] -name = "chrono" -version = "0.4.42" +name = "chacha20" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" 
+checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" dependencies = [ - "iana-time-zone", + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", "js-sys", "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.0", + "windows-link", ] [[package]] @@ -1437,33 +1258,11 @@ dependencies = [ "half", ] -[[package]] -name = "clang-sys" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" -dependencies = [ - "glob", - "libc", - "libloading 0.8.9", -] - -[[package]] -name = "clap" -version = "2.34.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" -dependencies = [ - "bitflags 1.3.2", - "textwrap", - "unicode-width 0.1.14", -] - [[package]] name = "clap" -version = "4.5.50" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2cfd7bf8a6017ddaa4e32ffe7403d547790db06bd171c1c53926faab501623" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" dependencies = [ "clap_builder", "clap_derive", @@ -1471,9 +1270,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.50" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4c05b9e80c5ccd3a7ef080ad7b6ba7d6fc00a985b8b157197075677c82c7a0" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstream", "anstyle", @@ -1483,21 +1282,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.49" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" +checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "clap_lex" -version = "0.7.5" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clipboard-win" @@ -1510,53 +1309,66 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.54" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +checksum = "c0f78a02292a74a88ac736019ab962ece0bc380e3f977bf72e376c5d78ff0678" dependencies = [ "cc", ] +[[package]] +name = "cmov" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746" + [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "comfy-table" -version = "7.1.2" +version = "7.2.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" +checksum = "958c5d6ecf1f214b4c2bbbbf6ab9523a864bd136dcf71a7e8904799acfe1ad47" dependencies = [ - "strum 0.26.3", - "strum_macros 0.26.4", - "unicode-width 0.2.1", + "unicode-segmentation", + "unicode-width 0.2.2", ] [[package]] -name = "console" -version = "0.15.11" +name = "compression-codecs" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +checksum = "ce2548391e9c1929c21bf6aa2680af86fe4c1b33e6cea9ac1cfeec0bd11218cf" dependencies = [ - "encode_unicode", - "libc", - "once_cell", - "windows-sys 0.59.0", + "bzip2", + "compression-core", + "flate2", + "liblzma", + "memchr", + "zstd", + "zstd-safe", ] +[[package]] +name = "compression-core" +version = "0.4.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc14f565cf027a105f7a44ccf9e5b424348421a1d8952a8fc9d499d313107789" + [[package]] name = "console" -version = "0.16.1" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" +checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87" dependencies = [ "encode_unicode", "libc", - "once_cell", - "unicode-width 0.2.1", - "windows-sys 0.61.0", + "unicode-width 0.2.2", + "windows-sys 0.61.2", ] [[package]] @@ -1569,6 +1381,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "const-oid" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" + [[package]] name = "const-random" version = "0.1.18" @@ -1584,25 +1402,16 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "once_cell", "tiny-keccak", ] -[[package]] -name = "const_panic" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652" -dependencies = [ - "typewit", -] - [[package]] name = "constant_time_eq" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" [[package]] name = "core-foundation" @@ -1621,29 +1430,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] -name = "core_extensions" -version = "1.5.4" +name = "cpufeatures" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42bb5e5d0269fd4f739ea6cedaf29c16d81c27a7ce7582008e90eb50dcd57003" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ - "core_extensions_proc_macros", + "libc", ] [[package]] -name = "core_extensions_proc_macros" -version = "1.5.4" +name = "cpufeatures" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "533d38ecd2709b7608fb8e18e4504deb99e9a72879e6aa66373a76d8dc4259ea" +checksum = 
"8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] [[package]] -name = "cpufeatures" -version = "0.2.17" +name = "crc" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" dependencies = [ - "libc", + "crc-catalog", ] +[[package]] +name = "crc-catalog" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "217698eaf96b4a3f0bc4f3662aaa55bdf913cd54d7204591faa790070c6d0853" + [[package]] name = "crc32fast" version = "1.5.0" @@ -1655,19 +1473,21 @@ dependencies = [ [[package]] name = "criterion" -version = "0.7.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ + "alloca", "anes", "cast", "ciborium", - "clap 4.5.50", + "clap", "criterion-plot", "futures", "itertools 0.13.0", "num-traits", "oorandom", + "page_size", "plotters", "rayon", "regex", @@ -1680,23 +1500,14 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.6.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", "itertools 0.13.0", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1730,50 +1541,69 @@ checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", "typenum", ] +[[package]] +name = "crypto-common" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77727bb15fa921304124b128af125e7e3b968275d1b108b379190264f4423710" +dependencies = [ + "hybrid-array", +] + [[package]] name = "csv" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" dependencies = [ "csv-core", "itoa", "ryu", - "serde", + "serde_core", ] [[package]] name = "csv-core" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" dependencies = [ "memchr", ] [[package]] name = "ctor" -version = "0.6.1" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ffc71fcdcdb40d6f087edddf7f8f1f8f79e6cf922f555a9ee8779752d4819bd" +checksum = 
"83cf0d42651b16c6dfe68685716d18480d18a9c39c62d76e8cf3eb6ed5d8bcbf" dependencies = [ "ctor-proc-macro", "dtor", + "link-section", ] [[package]] name = "ctor-proc-macro" -version = "0.0.7" +version = "0.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a949c44fcacbbbb7ada007dc7acb34603dd97cd47de5d054f2b6493ecebb483" + +[[package]] +name = "ctutils" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52560adf09603e58c9a7ee1fe1dcb95a16927b17c127f0ac02d6e768a0e25bc1" +checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e" +dependencies = [ + "cmov", +] [[package]] name = "cty" @@ -1783,9 +1613,9 @@ checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" [[package]] name = "darling" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" dependencies = [ "darling_core", "darling_macro", @@ -1793,27 +1623,26 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" dependencies = [ - "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "darling_macro" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ "darling_core", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -1832,13 +1661,13 @@ dependencies = [ [[package]] name = "datafusion" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-schema", "async-trait", "bytes", - "bzip2 0.6.1", + "bzip2", "chrono", "criterion", "ctor", @@ -1877,16 +1706,18 @@ dependencies = [ "flate2", "futures", "glob", + "indexmap 2.14.0", "insta", "itertools 0.14.0", + "liblzma", "log", "nix", "object_store", "parking_lot", "parquet", - "paste", - "rand 0.9.2", + "rand 0.9.4", "rand_distr", + "recursive", "regex", "rstest", "serde", @@ -1898,15 +1729,18 @@ dependencies = [ "tokio", "url", "uuid", - "xz2", "zstd", ] [[package]] name = "datafusion-benchmarks" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", + "async-trait", + "bytes", + "clap", + "criterion", "datafusion", "datafusion-common", "datafusion-proto", @@ -1917,19 +1751,19 @@ dependencies = [ "mimalloc", "object_store", "parquet", - "rand 0.9.2", + "rand 0.9.4", "regex", "serde", "serde_json", "snmalloc-rs", - "structopt", + "tempfile", "tokio", "tokio-util", ] [[package]] name = "datafusion-catalog" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -1952,10 +1786,11 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", + "chrono", "datafusion-catalog", "datafusion-common", "datafusion-datasource", @@ -1970,19 +1805,18 @@ dependencies = [ "itertools 0.14.0", "log", "object_store", - "tokio", ] [[package]] name = "datafusion-cli" -version = "50.3.0" +version = "53.1.0" 
dependencies = [ "arrow", "async-trait", "aws-config", "aws-credential-types", "chrono", - "clap 4.5.50", + "clap", "ctor", "datafusion", "datafusion-common", @@ -1999,7 +1833,7 @@ dependencies = [ "regex", "rstest", "rustyline", - "testcontainers", + "serde_json", "testcontainers-modules", "tokio", "url", @@ -2007,34 +1841,35 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", - "apache-avro", "arrow", "arrow-ipc", + "arrow-schema", "chrono", + "criterion", + "foldhash 0.2.0", "half", - "hashbrown 0.14.5", + "hashbrown 0.17.0", "hex", - "indexmap 2.12.0", + "indexmap 2.14.0", "insta", + "itertools 0.14.0", "libc", "log", "object_store", "parquet", - "paste", - "pyo3", - "rand 0.9.2", + "rand 0.9.4", "recursive", "sqlparser", "tokio", + "uuid", "web-time", ] [[package]] name = "datafusion-common-runtime" -version = "50.3.0" +version = "53.1.0" dependencies = [ "futures", "log", @@ -2043,13 +1878,13 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-compression", "async-trait", "bytes", - "bzip2 0.6.1", + "bzip2", "chrono", "criterion", "datafusion-common", @@ -2064,21 +1899,23 @@ dependencies = [ "flate2", "futures", "glob", + "insta", "itertools 0.14.0", + "liblzma", "log", "object_store", - "rand 0.9.2", + "parking_lot", + "rand 0.9.4", "tempfile", "tokio", "tokio-util", "url", - "xz2", "zstd", ] [[package]] name = "datafusion-datasource-arrow" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-ipc", @@ -2101,26 +1938,24 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "apache-avro", "arrow", + "arrow-avro", "async-trait", "bytes", "datafusion-common", "datafusion-datasource", - "datafusion-physical-expr-common", + "datafusion-physical-expr-adapter", "datafusion-physical-plan", "datafusion-session", "futures", - "num-traits", "object_store", - "serde_json", ] [[package]] name = "datafusion-datasource-csv" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2141,7 +1976,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2156,23 +1991,28 @@ dependencies = [ "datafusion-session", "futures", "object_store", + "serde_json", "tokio", + "tokio-stream", ] [[package]] name = "datafusion-datasource-parquet" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", "bytes", "chrono", + "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate-common", + "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", @@ -2185,16 +2025,17 @@ dependencies = [ "object_store", "parking_lot", "parquet", + "tempfile", "tokio", ] [[package]] name = "datafusion-doc" -version = "50.3.0" +version = "53.1.0" [[package]] name = "datafusion-examples" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-flight", @@ -2204,18 +2045,25 @@ dependencies = [ "bytes", "dashmap", "datafusion", - "datafusion-ffi", + "datafusion-common", + "datafusion-expr", "datafusion-physical-expr-adapter", "datafusion-proto", + "datafusion-sql", "env_logger", "futures", + "insta", "log", "mimalloc", 
"nix", + "nom", "object_store", "prost", - "rand 0.9.2", + "rand 0.9.4", + "serde", "serde_json", + "strum", + "strum_macros", "tempfile", "test-utils", "tokio", @@ -2228,30 +2076,33 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", + "arrow-buffer", "async-trait", "chrono", "dashmap", "datafusion-common", "datafusion-expr", + "datafusion-physical-expr-common", "futures", "insta", "log", "object_store", "parking_lot", "parquet", - "rand 0.9.2", + "rand 0.9.4", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", + "arrow-schema", "async-trait", "chrono", "ctor", @@ -2262,10 +2113,9 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.12.0", + "indexmap 2.14.0", "insta", "itertools 0.14.0", - "paste", "recursive", "serde_json", "sqlparser", @@ -2273,40 +2123,54 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.12.0", + "indexmap 2.14.0", + "insta", "itertools 0.14.0", - "paste", ] [[package]] name = "datafusion-ffi" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "abi_stable", "arrow", "arrow-schema", "async-ffi", "async-trait", "datafusion", + "datafusion-catalog", "datafusion-common", + "datafusion-datasource", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", + "datafusion-functions-table", + "datafusion-functions-window", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-optimizer", + "datafusion-physical-plan", "datafusion-proto", "datafusion-proto-common", + "datafusion-session", "doc-comment", "futures", + "libloading", "log", "prost", "semver", + "stabby", "tokio", ] [[package]] name = "datafusion-functions" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-buffer", @@ -2314,6 +2178,7 @@ dependencies = [ "blake2", "blake3", "chrono", + "chrono-tz", "criterion", "ctor", "datafusion-common", @@ -2322,25 +2187,25 @@ dependencies = [ "datafusion-expr", "datafusion-expr-common", "datafusion-macros", + "datafusion-physical-expr-common", "env_logger", "hex", "itertools 0.14.0", "log", - "md-5", + "md-5 0.11.0", + "memchr", "num-traits", - "rand 0.9.2", + "rand 0.9.4", "regex", "sha2", "tokio", - "unicode-segmentation", "uuid", ] [[package]] name = "datafusion-functions-aggregate" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2351,28 +2216,28 @@ dependencies = [ "datafusion-macros", "datafusion-physical-expr", "datafusion-physical-expr-common", + "foldhash 0.2.0", "half", "log", - "paste", - "rand 0.9.2", + "num-traits", + "rand 0.9.4", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", "arrow", "criterion", "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand 0.9.2", + "rand 0.9.4", ] [[package]] name = "datafusion-functions-nested" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "arrow-ord", @@ -2386,16 +2251,19 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-macros", + "datafusion-physical-expr", 
"datafusion-physical-expr-common", + "hashbrown 0.17.0", "itertools 0.14.0", + "itoa", "log", - "paste", - "rand 0.9.2", + "memchr", + "rand 0.9.4", ] [[package]] name = "datafusion-functions-table" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2404,14 +2272,14 @@ dependencies = [ "datafusion-expr", "datafusion-physical-plan", "parking_lot", - "paste", ] [[package]] name = "datafusion-functions-window" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", + "criterion", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -2420,12 +2288,11 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "log", - "paste", ] [[package]] name = "datafusion-functions-window-common" -version = "50.3.0" +version = "53.1.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2433,16 +2300,16 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "50.3.0" +version = "53.1.0" dependencies = [ "datafusion-doc", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "datafusion-optimizer" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", @@ -2458,7 +2325,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.12.0", + "indexmap 2.14.0", "insta", "itertools 0.14.0", "log", @@ -2469,9 +2336,8 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2481,20 +2347,21 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", - "hashbrown 0.14.5", - "indexmap 2.12.0", + "hashbrown 0.17.0", + "indexmap 2.14.0", "insta", "itertools 0.14.0", "parking_lot", - "paste", - "petgraph 0.8.3", - "rand 0.9.2", + "petgraph", + "rand 0.9.4", + "recursive", "rstest", + "tokio", ] [[package]] name = "datafusion-physical-expr-adapter" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2507,19 +2374,24 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", "arrow", + "chrono", + "criterion", "datafusion-common", "datafusion-expr-common", - "hashbrown 0.14.5", + "hashbrown 0.17.0", + "indexmap 2.14.0", "itertools 0.14.0", + "parking_lot", + "pin-project", + "rand 0.9.4", ] [[package]] name = "datafusion-physical-optimizer" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2527,6 +2399,7 @@ dependencies = [ "datafusion-expr", "datafusion-expr-common", "datafusion-functions", + "datafusion-functions-window", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", @@ -2539,19 +2412,19 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "50.3.0" +version = "53.1.0" dependencies = [ - "ahash 0.8.12", "arrow", + "arrow-data", "arrow-ord", "arrow-schema", "async-trait", - "chrono", "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-functions-window", @@ -2560,14 +2433,15 @@ dependencies = [ "datafusion-physical-expr-common", "futures", "half", - "hashbrown 0.14.5", - "indexmap 2.12.0", + "hashbrown 0.17.0", + "indexmap 2.14.0", "insta", "itertools 0.14.0", "log", + 
"num-traits", "parking_lot", "pin-project-lite", - "rand 0.9.2", + "rand 0.9.4", "rstest", "rstest_reuse", "tokio", @@ -2575,9 +2449,10 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", + "async-trait", "chrono", "datafusion", "datafusion-catalog", @@ -2601,7 +2476,7 @@ dependencies = [ "datafusion-proto-common", "doc-comment", "object_store", - "pbjson", + "pbjson 0.9.0", "pretty_assertions", "prost", "serde", @@ -2611,19 +2486,19 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", "doc-comment", - "pbjson", + "pbjson 0.9.0", "prost", "serde", ] [[package]] name = "datafusion-pruning" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "datafusion-common", @@ -2641,7 +2516,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "50.3.0" +version = "53.1.0" dependencies = [ "async-trait", "datafusion-common", @@ -2653,27 +2528,35 @@ dependencies = [ [[package]] name = "datafusion-spark" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "bigdecimal", "chrono", "crc32fast", "criterion", + "datafusion", "datafusion-catalog", "datafusion-common", "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", + "datafusion-functions-nested", "log", - "rand 0.9.2", - "sha1", + "num-traits", + "percent-encoding", + "rand 0.9.4", + "serde_json", + "sha1 0.11.0", + "sha2", "url", ] [[package]] name = "datafusion-sql" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "bigdecimal", @@ -2686,11 +2569,10 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.12.0", + "indexmap 2.14.0", "insta", "itertools 0.14.0", "log", - "paste", "recursive", "regex", "rstest", @@ -2699,14 +2581,14 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "50.3.0" +version = "53.1.0" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.50", + "clap", "datafusion", "datafusion-spark", "datafusion-substrait", @@ -2717,14 +2599,12 @@ dependencies = [ "itertools 0.14.0", "log", "object_store", - "postgres-protocol", "postgres-types", "regex", - "rust_decimal", + "serde_json", "sqllogictest", "sqlparser", "tempfile", - "testcontainers", "testcontainers-modules", "thiserror", "tokio", @@ -2733,7 +2613,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "50.3.0" +version = "53.1.0" dependencies = [ "async-recursion", "async-trait", @@ -2750,13 +2630,13 @@ dependencies = [ "substrait", "tokio", "url", - "uuid", ] [[package]] name = "datafusion-wasmtest" -version = "50.3.0" +version = "53.1.0" dependencies = [ + "bytes", "chrono", "console_error_panic_hook", "datafusion", @@ -2766,6 +2646,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", + "futures", "getrandom 0.3.4", "object_store", "tokio", @@ -2776,12 +2657,12 @@ dependencies = [ [[package]] name = "deranged" -version = "0.5.3" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" dependencies = [ "powerfmt", - "serde", + "serde_core", ] [[package]] @@ -2796,11 +2677,23 @@ version = "0.10.7" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", - "crypto-common", + "block-buffer 0.10.4", + "crypto-common 0.1.7", "subtle", ] +[[package]] +name = "digest" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" +dependencies = [ + "block-buffer 0.12.0", + "const-oid", + "crypto-common 0.2.1", + "ctutils", +] + [[package]] name = "dirs" version = "6.0.0" @@ -2819,7 +2712,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.61.0", + "windows-sys 0.60.2", ] [[package]] @@ -2830,14 +2723,14 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "doc-comment" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" [[package]] name = "docker_credential" @@ -2852,18 +2745,18 @@ dependencies = [ [[package]] name = "dtor" -version = "0.1.1" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "404d02eeb088a82cfd873006cb713fe411306c7d182c344905e101fb1167d301" +checksum = "edf234dd1594d6dd434a8fb8cada51ddbbc593e40e4a01556a0b31c62da2775b" dependencies = [ "dtor-proc-macro", ] [[package]] name = "dtor-proc-macro" -version = "0.0.6" +version = "0.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5" +checksum = "2647271c92754afcb174e758003cfd1cbf1e43e5a7853d7b1813e63e19e39a73" [[package]] name = "dunce" @@ -2886,7 +2779,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -2903,35 +2796,35 @@ checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "endian-type" -version = "0.1.2" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +checksum = "869b0adbda23651a9c5c0c3d270aac9fcb52e8622a8f2b17e57802d7791962f2" [[package]] name = "enum-ordinalize" -version = "4.3.0" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" dependencies = [ "enum-ordinalize-derive", ] [[package]] name = "enum-ordinalize-derive" -version = "4.3.1" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "env_filter" -version = "0.1.3" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", "regex", @@ -2939,9 +2832,9 @@ dependencies = [ 
[[package]] name = "env_logger" -version = "0.11.8" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ "anstream", "anstyle", @@ -2963,7 +2856,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.0", + "windows-sys 0.52.0", ] [[package]] @@ -2980,13 +2873,12 @@ checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" [[package]] name = "etcetera" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c7b13d0780cb82722fd59f6f57f925e143427e4a75313a6c77243bf5326ae6" +checksum = "de48cc4d1c1d97a20fd819def54b890cadde72ed3ad0c614822a0a433361be96" dependencies = [ "cfg-if", - "home", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2997,26 +2889,25 @@ checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" [[package]] name = "fastrand" -version = "2.3.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" [[package]] -name = "fd-lock" -version = "4.0.4" +name = "ferroid" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" +checksum = "ee93edf3c501f0035bbeffeccfed0b79e14c311f12195ec0e661e114a0f60da4" dependencies = [ - "cfg-if", - "rustix", - "windows-sys 0.59.0", + "portable-atomic", + "rand 0.10.1", + "web-time", ] [[package]] name = "ffi_example_table_provider" version = "0.1.0" dependencies = [ - "abi_stable", "arrow", "datafusion", "datafusion-ffi", @@ -3027,7 +2918,6 @@ dependencies = [ name = "ffi_module_interface" version = "0.1.0" dependencies = [ - "abi_stable", "datafusion-ffi", ] @@ -3035,30 +2925,29 @@ dependencies = [ name = "ffi_module_loader" version = "0.1.0" dependencies = [ - "abi_stable", "datafusion", "datafusion-ffi", "ffi_module_interface", + "libloading", "tokio", ] [[package]] name = "filetime" -version = "0.2.26" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" dependencies = [ "cfg-if", "libc", "libredox", - "windows-sys 0.60.2", ] [[package]] name = "find-msvc-tools" -version = "0.1.2" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ced73b1dacfc750a6db6c0a0c3a3853c8b41997e2e2c563dc90804ae6867959" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "fixedbitset" @@ -3068,23 +2957,23 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flatbuffers" -version = "25.2.10" +version = "25.12.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" +checksum = "35f6839d7b3b98adde531effaf34f0c2badc6f4735d26fe74709d8e513a96ef3" dependencies = [ - "bitflags 2.9.4", + "bitflags", "rustc_version", ] 
[[package]] name = "flate2" -version = "1.1.4" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", - "libz-rs-sys", "miniz_oxide", + "zlib-rs", ] [[package]] @@ -3099,6 +2988,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -3110,9 +3005,9 @@ dependencies = [ [[package]] name = "fs-err" -version = "3.1.2" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44f150ffc8782f35521cec2b23727707cb4045706ba3c854e86bef66b3a8cdbd" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" dependencies = [ "autocfg", ] @@ -3123,17 +3018,11 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -3146,9 +3035,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -3156,15 +3045,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -3173,32 +3062,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = 
"e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" @@ -3208,9 +3097,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -3220,7 +3109,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -3228,7 +3116,7 @@ dependencies = [ name = "gen" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3236,24 +3124,15 @@ dependencies = [ name = "gen-common" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] [[package]] -name = "generational-arena" -version = "0.2.9" +name = "generic-array" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877e94aff08e743b651baaea359664321055749b398adff8740a7399af7796e7" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", @@ -3261,14 +3140,14 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -3281,11 +3160,25 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "rand_core 0.10.1", + "wasip2", + "wasip3", +] + [[package]] name = "glob" version = "0.3.3" @@ -3294,9 +3187,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "globset" -version = "0.4.16" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a1028dfc5f5df5da8a56a73e6c153c9a9708ec57232470703592a3f18e49f5" +checksum = 
"52dfc19153a48bde0cbd630453615c8151bce3a5adfac7a0aebfbf0a1e1f57e3" dependencies = [ "aho-corasick", "bstr", @@ -3307,17 +3200,17 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.3.1", - "indexmap 2.12.0", + "http 1.4.0", + "indexmap 2.14.0", "slab", "tokio", "tokio-util", @@ -3333,6 +3226,8 @@ dependencies = [ "cfg-if", "crunchy", "num-traits", + "rand 0.9.4", + "rand_distr", "zerocopy", ] @@ -3341,19 +3236,12 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash 0.7.8", -] [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash 0.8.12", - "allocator-api2", -] [[package]] name = "hashbrown" @@ -3361,24 +3249,29 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "allocator-api2", - "equivalent", - "foldhash", + "foldhash 0.1.5", ] [[package]] name = "hashbrown" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", +] [[package]] -name = "heck" -version = "0.3.3" +name = "hashbrown" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" dependencies = [ - "unicode-segmentation", + "allocator-api2", + "equivalent", + "foldhash 0.2.0", ] [[package]] @@ -3395,20 +3288,20 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hmac" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" dependencies = [ - "digest", + "digest 0.11.2", ] [[package]] name = "home" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3424,12 +3317,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -3451,7 +3343,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.3.1", + "http 1.4.0", ] [[package]] @@ -3462,7 +3354,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "pin-project-lite", ] @@ -3485,24 +3377,32 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" +[[package]] +name = "hybrid-array" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d46837a0ed51fe95bd3b05de33cd64a1ee88fc797477ca48446872504507c5" +dependencies = [ + "typenum", +] + [[package]] name = "hyper" -version = "1.7.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" dependencies = [ "atomic-waker", "bytes", "futures-channel", "futures-core", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "httparse", "httpdate", "itoa", "pin-project-lite", - "pin-utils", "smallvec", "tokio", "want", @@ -3525,16 +3425,15 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.7" +version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ - "http 1.3.1", + "http 1.4.0", "hyper", "hyper-util", "rustls", "rustls-native-certs", - "rustls-pki-types", "tokio", "tokio-rustls", "tower-service", @@ -3555,16 +3454,15 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.17" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ "base64 0.22.1", "bytes", "futures-channel", - "futures-core", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "hyper", "ipnet", @@ -3594,9 +3492,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.64" +version = "0.1.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -3604,7 +3502,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.62.0", + "windows-core", ] [[package]] @@ -3618,12 +3516,13 @@ dependencies = [ [[package]] name = "icu_collections" -version = "2.0.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" dependencies = [ "displaydoc", "potential_utf", + "utf8_iter", "yoke", "zerofrom", "zerovec", @@ -3631,9 +3530,9 @@ dependencies = [ [[package]] name = "icu_locale_core" -version = "2.0.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", @@ -3644,11 +3543,10 @@ dependencies = [ [[package]] name = "icu_normalizer" -version = "2.0.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" dependencies = [ - "displaydoc", "icu_collections", "icu_normalizer_data", "icu_properties", @@ -3659,42 +3557,38 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "2.0.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" [[package]] name = "icu_properties" -version = "2.0.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" dependencies = [ - "displaydoc", "icu_collections", "icu_locale_core", "icu_properties_data", "icu_provider", - "potential_utf", "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "2.0.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" [[package]] name = "icu_provider" -version = "2.0.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", - "stable_deref_trait", - "tinystr", "writeable", "yoke", "zerofrom", @@ -3702,6 +3596,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "ident_case" version = "1.0.1" @@ -3742,47 +3642,42 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.16.0", + "hashbrown 0.17.0", "serde", "serde_core", ] [[package]] name = "indicatif" -version = "0.18.0" +version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a646d946d06bedbbc4cac4c218acf4bbf2d87757a784857025f4d447e4e1cd" +checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console 0.16.1", + "console", "portable-atomic", - "unicode-width 0.2.1", + "unicode-width 0.2.2", "unit-prefix", "web-time", ] -[[package]] -name = "indoc" -version = "2.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" - [[package]] name = "insta" -version = "1.43.2" 
+version = "1.47.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fdb647ebde000f43b5b53f773c30cf9b0cb4300453208713fa38b2c70935a0" +checksum = "7b4a6248eb93a4401ed2f37dfe8ea592d3cf05b7cf4f8efa867b6895af7e094e" dependencies = [ - "console 0.15.11", + "console", "globset", "once_cell", "regex", "serde", "similar", + "tempfile", "walkdir", ] @@ -3805,15 +3700,15 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" -version = "0.7.8" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" dependencies = [ "memchr", "serde", @@ -3821,9 +3716,9 @@ dependencies = [ [[package]] name = "is_terminal_polyfill" -version = "1.70.1" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" [[package]] name = "itertools" @@ -3845,32 +3740,32 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" -version = "0.2.15" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49" +checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" dependencies = [ "jiff-static", "log", "portable-atomic", "portable-atomic-util", - "serde", + "serde_core", ] [[package]] name = "jiff-static" -version = "0.2.15" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" +checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -3885,10 +3780,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" dependencies = [ + "cfg-if", + "futures-util", "once_cell", "wasm-bindgen", ] @@ -3899,6 +3796,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "lexical-core" version = "1.0.6" @@ -3958,112 +3861,118 @@ dependencies = [ [[package]] name = "libbz2-rs-sys" -version = "0.2.2" +version = "0.2.3" source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c4a545a15244c7d945065b5d392b2d2d7f21526fba56ce51467b06ed445e8f7" +checksum = "b3a6a8c165077efc8f3a971534c50ea6a1a18b329ef4a66e897a7e3a1494565f" [[package]] name = "libc" -version = "0.2.177" +version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "libloading" -version = "0.7.4" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" dependencies = [ "cfg-if", - "winapi", + "windows-link", ] [[package]] -name = "libloading" -version = "0.8.9" +name = "liblzma" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" dependencies = [ - "cfg-if", - "windows-link 0.2.0", + "liblzma-sys", +] + +[[package]] +name = "liblzma-sys" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a60851d15cd8c5346eca4ab8babff585be2ae4bc8097c067291d3ffe2add3b6" +dependencies = [ + "cc", + "libc", + "pkg-config", ] [[package]] name = "libm" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libmimalloc-sys" -version = "0.1.44" +version = "0.1.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "667f4fec20f29dfc6bc7357c582d91796c169ad7e2fce709468aefeb2c099870" +checksum = "2d1eacfa31c33ec25e873c136ba5669f00f9866d0688bea7be4d3f7e43067df6" dependencies = [ "cc", "cty", - "libc", ] [[package]] name = "libredox" -version = "0.1.10" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" dependencies = [ - "bitflags 2.9.4", + "bitflags", "libc", - "redox_syscall", + "plain", + "redox_syscall 0.7.4", ] [[package]] name = "libtest-mimic" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" +checksum = "14e6ba06f0ade6e504aff834d7c34298e5155c6baca353cc6a4aaff2f9fd7f33" dependencies = [ "anstream", "anstyle", - "clap 4.5.50", + "clap", "escape8259", ] [[package]] -name = "libz-rs-sys" -version = "0.5.2" +name = "link-section" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" -dependencies = [ - "zlib-rs", -] +checksum = "b685d66585d646efe09fec763d796c291049c8b6bf84e04954bffc8748341f0d" [[package]] name = "linux-raw-sys" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +checksum = 
"32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" -version = "0.8.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" [[package]] name = "lock_api" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" dependencies = [ - "autocfg", "scopeguard", ] [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -4073,24 +3982,13 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "lz4_flex" -version = "0.11.5" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a" +checksum = "db9a0d582c2874f68138a16ce1867e0ffde6c0bb0a0df85e1f36d04146db488a" dependencies = [ "twox-hash", ] -[[package]] -name = "lzma-sys" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "matchit" version = "0.8.4" @@ -4104,29 +4002,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ "cfg-if", - "digest", + "digest 0.10.7", ] [[package]] -name = "memchr" -version = "2.7.5" +name = "md-5" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +checksum = "69b6441f590336821bb897fb28fc622898ccceb1d6cea3fde5ea86b090c4de98" +dependencies = [ + "cfg-if", + "digest 0.11.2", +] [[package]] -name = "memoffset" -version = "0.9.1" +name = "memchr" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "mimalloc" -version = "0.1.48" +version = "0.1.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1ee66a4b64c74f4ef288bcbb9192ad9c3feaad75193129ac8509af543894fd8" +checksum = "b3627c4272df786b9260cabaa46aec1d59c93ede723d4c3ef646c503816b0640" dependencies = [ "libmimalloc-sys", ] @@ -4139,20 +4038,14 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "minicov" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f27fe9f1cc3c22e1687f9446c2083c4c5fc7f0bcf1c7a86bdbded14985895b4b" +checksum = "4869b6a491569605d66d3952bcdf03df789e5b536e5f0cf7758a7f08a55ae24d" dependencies = [ "cc", "walkdir", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "miniz_oxide" version = "0.8.9" @@ -4165,13 +4058,13 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.4" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" dependencies = [ "libc", - "wasi", - "windows-sys 0.59.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.61.2", ] [[package]] @@ -4191,11 +4084,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.30.1" +version = "0.31.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" dependencies = [ - "bitflags 2.9.4", + "bitflags", "cfg-if", "cfg_aliases", "libc", @@ -4203,30 +4096,29 @@ dependencies = [ [[package]] name = "nom" -version = "7.1.3" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" dependencies = [ "memchr", - "minimal-lexical", ] [[package]] name = "ntapi" -version = "0.4.1" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" dependencies = [ "winapi", ] [[package]] name = "nu-ansi-term" -version = "0.50.1" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -4251,7 +4143,6 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", - "serde", ] [[package]] @@ -4265,9 +4156,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" [[package]] name = "num-integer" @@ -4312,48 +4203,68 @@ dependencies = [ [[package]] name = "objc2-core-foundation" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ - "bitflags 2.9.4", + "bitflags", ] [[package]] name = "objc2-io-kit" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +checksum = "33fafba39597d6dc1fb709123dfa8289d39406734be322956a69f0931c73bb15" dependencies = [ "libc", "objc2-core-foundation", ] +[[package]] +name = "objc2-system-configuration" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" +dependencies = [ + "objc2-core-foundation", +] + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + [[package]] name = "object_store" -version = "0.12.4" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" +checksum = "622acbc9100d3c10e2ee15804b0caa40e55c933d5aa53814cd520805b7958a49" dependencies = [ "async-trait", "base64 0.22.1", "bytes", "chrono", "form_urlencoded", - "futures", - "http 1.3.1", + "futures-channel", + "futures-core", + "futures-util", + "http 1.4.0", "http-body-util", "humantime", "hyper", "itertools 0.14.0", - "md-5", + "md-5 0.10.6", "parking_lot", "percent-encoding", "quick-xml", - "rand 0.9.2", + "rand 0.10.1", "reqwest", "ring", - "rustls-pemfile", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", @@ -4368,15 +4279,15 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "once_cell_polyfill" -version = "1.70.1" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "oorandom" @@ -4386,9 +4297,9 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" -version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "option-ext" @@ -4413,15 +4324,25 @@ checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "owo-colors" -version = "4.2.2" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" + +[[package]] +name = "page_size" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48dd4f4a2c8405440fd0462561f0e5806bd0f77e86f51c761481bdd4018b545e" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] [[package]] name = "parking_lot" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" dependencies = [ "lock_api", "parking_lot_core", @@ -4429,27 +4350,26 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.11" +version = "0.9.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", 
"libc", - "redox_syscall", + "redox_syscall 0.5.18", "smallvec", - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "parquet" -version = "57.0.0" +version = "58.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a0f31027ef1af7549f7cec603a9a21dce706d3f8d7c2060a68f43c1773be95a" +checksum = "7d3f9f2205199603564127932b89695f52b62322f541d0fc7179d57c2e1c9877" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-data", "arrow-ipc", "arrow-schema", @@ -4461,7 +4381,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "lz4_flex", "num-bigint", "num-integer", @@ -4500,7 +4420,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -4519,13 +4439,35 @@ dependencies = [ "serde", ] +[[package]] +name = "pbjson" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8edd1efdd8ab23ba9cb9ace3d9987a72663d5d7c9f74fa00b51d6213645cf6c" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "pbjson-build" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095" dependencies = [ - "heck 0.5.0", + "heck", + "itertools 0.14.0", + "prost", + "prost-types", +] + +[[package]] +name = "pbjson-build" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ed4d5c6ae95e08ac768883c8401cf0e8deb4e6e1d6a4e1fd3d2ec4f0ec63200" +dependencies = [ + "heck", "itertools 0.14.0", "prost", "prost-types", @@ -4539,8 +4481,8 @@ checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" dependencies = [ "bytes", "chrono", - "pbjson", - "pbjson-build", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "prost", "prost-build", "serde", @@ -4552,16 +4494,6 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" -[[package]] -name = "petgraph" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" -dependencies = [ - "fixedbitset", - "indexmap 2.12.0", -] - [[package]] name = "petgraph" version = "0.8.3" @@ -4570,7 +4502,7 @@ checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", "hashbrown 0.15.5", - "indexmap 2.12.0", + "indexmap 2.14.0", "serde", ] @@ -4613,29 +4545,29 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "pin-project-lite" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pin-utils" @@ -4645,9 +4577,15 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.32" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + +[[package]] +name = "plain" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" [[package]] name = "plotters" @@ -4679,54 +4617,54 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.11.1" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.4" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" dependencies = [ "portable-atomic", ] [[package]] name = "postgres-derive" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56df96f5394370d1b20e49de146f9e6c25aa9ae750f449c9d665eafecb3ccae6" +checksum = "ca1dad89d9ffdbf78502fde418eeede499b87772d88be780478f7f76dc8d471f" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "postgres-protocol" -version = "0.6.9" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4" +checksum = "56201207dac53e2f38e848e31b4b91616a6bb6e0c7205b77718994a7f49e70fc" dependencies = [ "base64 0.22.1", "byteorder", "bytes", "fallible-iterator", "hmac", - "md-5", + "md-5 0.11.0", "memchr", - "rand 0.9.2", + "rand 0.10.1", "sha2", "stringprep", ] [[package]] name = "postgres-types" -version = "0.2.11" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4605b7c057056dd35baeb6ac0c0338e4975b1f2bef0f65da953285eb007095" +checksum = "8dc729a129e682e8d24170cd30ae1aa01b336b096cbb56df6d534ffec133d186" dependencies = [ "bytes", "chrono", @@ -4737,9 +4675,9 @@ dependencies = [ [[package]] name = "potential_utf" -version = "0.1.3" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ "zerovec", ] @@ -4776,56 +4714,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "proc-macro-crate" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +checksum = 
"e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ "toml_edit", ] [[package]] -name = "proc-macro-error" -version = "1.0.4" +name = "proc-macro2" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", + "unicode-ident", ] [[package]] -name = "proc-macro-error-attr" -version = "1.0.4" +name = "prost" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - -[[package]] -name = "proc-macro2" -version = "1.0.101" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "prost" -version = "0.14.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -4833,42 +4747,41 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ - "heck 0.5.0", + "heck", "itertools 0.14.0", "log", "multimap", - "once_cell", - "petgraph 0.7.1", + "petgraph", "prettyplease", "prost", "prost-types", "regex", - "syn 2.0.108", + "syn 2.0.117", "tempfile", ] [[package]] name = "prost-derive" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "prost-types" -version = "0.14.1" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -4884,105 +4797,19 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.26" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +checksum = "645dbe486e346d9b5de3ef16ede18c26e6c70ad97418f4874b8b1889d6e761ea" dependencies = [ + "ar_archive_writer", "cc", ] -[[package]] -name = "ptr_meta" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" -dependencies = [ - "ptr_meta_derive", -] - -[[package]] -name = "ptr_meta_derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" -dependencies = [ 
- "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "pyo3" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" -dependencies = [ - "indoc", - "libc", - "memoffset", - "once_cell", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f" -dependencies = [ - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.108", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.108", -] - -[[package]] -name = "quad-rand" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" - [[package]] name = "quick-xml" -version = "0.38.3" +version = "0.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" +checksum = "958f21e8e7ceb5a1aa7fa87fab28e7c75976e0bfe7e23ff069e0a260f894067d" dependencies = [ "memchr", "serde", @@ -5010,14 +4837,14 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand 0.9.2", + "rand 0.9.4", "ring", "rustc-hash", "rustls", @@ -5040,14 +4867,14 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.52.0", ] [[package]] name = "quote" -version = "1.0.41" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -5059,16 +4886,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] -name = "radium" -version = "0.7.0" +name = "r-efi" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "radix_trie" -version = "0.2.1" +version = "0.3.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +checksum = "3b4431027dcd37fc2a73ef740b5f233aa805897935b8bce0195e41bbf9a3289a" dependencies = [ "endian-type", "nibble_vec", @@ -5076,9 +4903,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha 0.3.1", @@ -5087,12 +4914,23 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.2" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.3", + "rand_core 0.9.5", +] + +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", ] [[package]] @@ -5112,7 +4950,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -5121,18 +4959,24 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", ] [[package]] name = "rand_core" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + [[package]] name = "rand_distr" version = "0.5.1" @@ -5140,14 +4984,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.2", + "rand 0.9.4", ] [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -5180,16 +5024,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "redox_syscall" -version = "0.5.17" +version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 2.9.4", + "bitflags", +] + 
+[[package]] +name = "redox_syscall" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f450ad9c3b1da563fb6948a8e0fb0fb9269711c9c73d9ea1de5058c79c8d643a" +dependencies = [ + "bitflags", ] [[package]] @@ -5198,36 +5051,36 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "libredox", "thiserror", ] [[package]] name = "ref-cast" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" dependencies = [ "ref-cast-impl", ] [[package]] name = "ref-cast-impl" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "regex" -version = "1.12.2" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -5237,9 +5090,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -5248,23 +5101,23 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.7" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943f41321c63ef1c92fd763bfe054d2668f7f225a5c29f0105903dc2fc04ba30" +checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" [[package]] name = "regex-syntax" -version = "0.8.6" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "regress" -version = "0.10.4" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145bb27393fe455dd64d6cbc8d059adfa392590a45eadf079c01b11857e7b010" +checksum = "2057b2325e68a893284d1538021ab90279adac1139957ca2a74426c6f118fb48" dependencies = [ - "hashbrown 0.15.5", + "hashbrown 0.16.1", "memchr", ] @@ -5274,36 +5127,18 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" -[[package]] -name = "rend" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71fe3824f5629716b1589be05dacd749f6aa084c87e00e016714a8cdfccc997c" -dependencies = [ - "bytecheck", -] - -[[package]] -name = "repr_offset" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb1070755bd29dffc19d0971cab794e607839ba2ef4b69a9e6fbc8733c1b72ea" 
-dependencies = [ - "tstr", -] - [[package]] name = "reqwest" -version = "0.12.23" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper", @@ -5342,41 +5177,12 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.16", + "getrandom 0.2.17", "libc", "untrusted", "windows-sys 0.52.0", ] -[[package]] -name = "rkyv" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b" -dependencies = [ - "bitvec", - "bytecheck", - "bytes", - "hashbrown 0.12.3", - "ptr_meta", - "rend", - "rkyv_derive", - "seahash", - "tinyvec", - "uuid", -] - -[[package]] -name = "rkyv_derive" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "rstest" version = "0.26.1" @@ -5402,7 +5208,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.108", + "syn 2.0.117", "unicode-ident", ] @@ -5413,32 +5219,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", - "rand 0.8.5", - "syn 2.0.108", -] - -[[package]] -name = "rust_decimal" -version = "1.38.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8975fc98059f365204d635119cf9c5a60ae67b841ed49b5422a9a7e56cdfac0" -dependencies = [ - "arrayvec", - "borsh", - "bytes", - "num-traits", - "postgres-types", - "rand 0.8.5", - "rkyv", - "serde", - "serde_json", + "rand 0.8.6", + "syn 2.0.117", ] [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustc_version" @@ -5451,22 +5240,22 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ - "bitflags 2.9.4", + "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.0", + "windows-sys 0.52.0", ] [[package]] name = "rustls" -version = "0.23.32" +version = "0.23.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" +checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e" dependencies = [ "aws-lc-rs", "log", @@ -5480,9 +5269,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +checksum = 
"612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -5490,20 +5279,11 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" -version = "1.12.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" dependencies = [ "web-time", "zeroize", @@ -5511,9 +5291,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.6" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -5529,14 +5309,13 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "rustyline" -version = "17.0.2" +version = "18.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e902948a25149d50edc1a8e0141aad50f54e22ba83ff988cf8f7c9ef07f50564" +checksum = "4a990b25f351b25139ddc7f21ee3f6f56f86d6846b74ac8fad3a719a287cd4a0" dependencies = [ - "bitflags 2.9.4", + "bitflags", "cfg-if", "clipboard-win", - "fd-lock", "home", "libc", "log", @@ -5544,16 +5323,16 @@ dependencies = [ "nix", "radix_trie", "unicode-segmentation", - "unicode-width 0.2.1", + "unicode-width 0.2.2", "utf8parse", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] name = "same-file" @@ -5566,11 +5345,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" dependencies = [ - "windows-sys 0.61.0", + "windows-sys 0.61.2", ] [[package]] @@ -5599,9 +5378,9 @@ dependencies = [ [[package]] name = "schemars" -version = "1.0.4" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" dependencies = [ "dyn-clone", "ref-cast", @@ -5618,7 +5397,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -5627,19 +5406,13 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "seahash" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" - [[package]] name = "security-framework" -version = "3.5.0" +version = 
"3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc198e42d9b7510827939c9a15f5062a0c913f3371d765977e586d2fe6c16f4a" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags 2.9.4", + "bitflags", "core-foundation", "core-foundation-sys", "libc", @@ -5648,9 +5421,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.15.0" +version = "2.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" dependencies = [ "core-foundation-sys", "libc", @@ -5658,9 +5431,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" dependencies = [ "serde", "serde_core", @@ -5682,16 +5455,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde_bytes" -version = "0.11.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" -dependencies = [ - "serde", - "serde_core", -] - [[package]] name = "serde_core" version = "1.0.228" @@ -5709,7 +5472,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -5720,20 +5483,21 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ + "indexmap 2.14.0", "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -5744,19 +5508,19 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "serde_tokenstream" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64060d864397305347a78851c51588fd283767e7e7589829e8121d65512340f1" +checksum = "d7c49585c52c01f13c5c2ebb333f14f6885d76daa768d8a037d28017ec538c69" dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -5773,19 +5537,18 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.14.1" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c522100790450cf78eeac1507263d0a350d4d5b30df0c8e1fe051a10c22b376e" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.12.0", + "indexmap 2.14.0", "schemars 0.9.0", - "schemars 1.0.4", - "serde", - "serde_derive", + "schemars 1.2.1", + "serde_core", "serde_json", "serde_with_macros", "time", @@ -5793,14 +5556,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.14.1" +version = "3.18.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "327ada00f7d64abaac1e55a6911e90cf665aa051b9a561c7006c157f4633135e" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -5809,7 +5572,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.14.0", "itoa", "ryu", "serde", @@ -5823,21 +5586,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", - "digest", + "cpufeatures 0.2.17", + "digest 0.10.7", +] + +[[package]] +name = "sha1" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aacc4cc499359472b4abe1bf11d0b12e688af9a805fa5e3016f9a386dc2d0214" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest 0.11.2", ] [[package]] name = "sha2" -version = "0.10.9" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" dependencies = [ "cfg-if", - "cpufeatures", - "digest", + "cpufeatures 0.3.0", + "digest 0.11.2", ] +[[package]] +name = "sha2-const-stable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f179d4e11094a893b82fff208f74d448a7512f99f5a0acbd5c679b705f83ed9" + [[package]] name = "sharded-slab" version = "0.1.7" @@ -5855,18 +5635,19 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.6" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] name = "simdutf8" @@ -5882,15 +5663,15 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" [[package]] name = "slab" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" [[package]] name = "smallvec" @@ -5906,37 +5687,37 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "snmalloc-rs" -version = "0.3.8" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"eb317153089fdfa4d8a2eec059d40a5a23c3bde43995ea23b19121c3f621e74a" +checksum = "530a04ae687609072d0edd38866406fbbcd23d2f716791437e312ec4d64a355a" dependencies = [ "snmalloc-sys", ] [[package]] name = "snmalloc-sys" -version = "0.3.8" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "065fea53d32bb77bc36cca466cb191f2e5216ebfd0ed360b1d64889ee6e559ea" +checksum = "a96cbeb16d6bcc5979f80ec907582a886b7fb3b9a707678b63dd93a10d8ee858" dependencies = [ "cmake", ] [[package]] name = "socket2" -version = "0.6.0" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "sqllogictest" -version = "0.28.4" +version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3566426f72a13e393aa34ca3d542c5b0eb86da4c0db137ee9b5cfccc6179e52d" +checksum = "d03b2262a244037b0b510edbd25a8e6c9fb8d73ee0237fc6cc95a54c16f94a82" dependencies = [ "async-trait", "educe", @@ -5946,9 +5727,9 @@ dependencies = [ "humantime", "itertools 0.13.0", "libtest-mimic", - "md-5", + "md-5 0.10.6", "owo-colors", - "rand 0.8.5", + "rand 0.8.6", "regex", "similar", "subst", @@ -5959,9 +5740,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.59.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4591acadbcf52f0af60eafbb2c003232b2b4cd8de5f0e9437cb8b1b59046cc0f" +checksum = "dbf5ea8d4d7c808e1af1cbabebca9a2abe603bcefc22294c5b95018d53200cb7" dependencies = [ "log", "recursive", @@ -5970,32 +5751,67 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.3.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" +checksum = "a6dd45d8fc1c79299bfbb7190e42ccbbdf6a5f52e4a6ad98d92357ea965bd289" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", +] + +[[package]] +name = "stabby" +version = "72.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976399a0c48ea769ef7f5dc303bb88240ab8d84008647a6b2303eced3dab3945" +dependencies = [ + "rustversion", + "stabby-abi", +] + +[[package]] +name = "stabby-abi" +version = "72.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7b54832a9a1f92a0e55e74a5c0332744426edc515bb3fbad82f10b874a87f0d" +dependencies = [ + "rustc_version", + "rustversion", + "sha2-const-stable", + "stabby-macros", +] + +[[package]] +name = "stabby-macros" +version = "72.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a768b1e51e4dbfa4fa52ae5c01241c0a41e2938fdffbb84add0c8238092f9091" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "rand 0.8.6", + "syn 1.0.109", ] [[package]] name = "stable_deref_trait" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" [[package]] name = "stacker" -version = "0.1.21" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +checksum = "640c8cdd92b6b12f5bcb1803ca3bbf5ab96e5e6b6b96b9ab77dabe9e880b3190" dependencies = [ "cc", "cfg-if", "libc", "psm", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -6024,7 +5840,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -6035,68 +5851,25 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", -] - -[[package]] -name = "structopt" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" -dependencies = [ - "clap 2.34.0", - "lazy_static", - "structopt-derive", -] - -[[package]] -name = "structopt-derive" -version = "0.4.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" -dependencies = [ - "heck 0.3.3", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", + "syn 2.0.117", ] [[package]] name = "strum" -version = "0.26.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" - -[[package]] -name = "strum" -version = "0.27.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" - -[[package]] -name = "strum_macros" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.108", -] +checksum = "9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" [[package]] name = "strum_macros" -version = "0.27.2" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +checksum = "ab85eea0270ee17587ed4156089e10b9e6880ee688791d45a905f5b1ca36f664" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -6111,13 +5884,14 @@ dependencies = [ [[package]] name = "substrait" -version = "0.62.0" +version = "0.63.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21f1cb6d0bcd097a39fc25f7236236be29881fe122e282e4173d6d007a929927" +checksum = "e620ff4d5c02fd6f7752931aa74b16a26af66a63022cc1ad412c77edbe0bab47" dependencies = [ - "heck 0.5.0", - "pbjson", - "pbjson-build", + "heck", + "indexmap 2.14.0", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "pbjson-types", "prettyplease", "prost", @@ -6130,7 +5904,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.108", + "syn 2.0.117", "typify", "walkdir", ] @@ -6154,9 +5928,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.108" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -6180,46 +5954,34 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] 
[[package]] name = "sysinfo" -version = "0.37.2" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16607d5caffd1c07ce073528f9ed972d88db15dd44023fa57142963be3feb11f" +checksum = "92ab6a2f8bfe508deb3c6406578252e491d299cbbf3bc0529ecc3313aee4a52f" dependencies = [ "libc", - "memchr", - "ntapi", - "objc2-core-foundation", - "objc2-io-kit", - "windows", -] - -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - -[[package]] -name = "target-lexicon" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" + "memchr", + "ntapi", + "objc2-core-foundation", + "objc2-io-kit", + "windows", +] [[package]] name = "tempfile" -version = "3.23.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.61.0", + "windows-sys 0.52.0", ] [[package]] @@ -6230,14 +5992,14 @@ dependencies = [ "chrono-tz", "datafusion-common", "env_logger", - "rand 0.9.2", + "rand 0.9.4", ] [[package]] name = "testcontainers" -version = "0.25.2" +version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f3ac71069f20ecfa60c396316c283fbf35e6833a53dff551a31b5458da05edc" +checksum = "bfd5785b5483672915ed5fe3cddf9f546802779fc1eceff0a6fb7321fac81c1e" dependencies = [ "astral-tokio-tar", "async-trait", @@ -6246,7 +6008,10 @@ dependencies = [ "docker_credential", "either", "etcetera", + "ferroid", "futures", + "http 1.4.0", + "itertools 0.14.0", "log", "memchr", "parse-display", @@ -6258,46 +6023,36 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "ulid", "url", ] [[package]] name = "testcontainers-modules" -version = "0.13.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1966329d5bb3f89d33602d2db2da971fb839f9297dad16527abf4564e2ae0a6d" +checksum = "e5985fde5befe4ffa77a052e035e16c2da86e8bae301baa9f9904ad3c494d357" dependencies = [ "testcontainers", ] -[[package]] -name = "textwrap" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" -dependencies = [ - "unicode-width 0.1.14", -] - [[package]] name = "thiserror" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] @@ -6322,30 +6077,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.44" +version = "0.3.47" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", @@ -6362,9 +6117,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", "zerovec", @@ -6382,9 +6137,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -6397,9 +6152,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.48.0" +version = "1.52.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" +checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" dependencies = [ "bytes", "libc", @@ -6409,25 +6164,25 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.61.0", + "windows-sys 0.61.2", ] [[package]] name = "tokio-macros" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "tokio-postgres" -version = "0.7.14" +version = "0.7.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156efe7fff213168257853e1dfde202eed5f487522cbbbf7d219941d753d853" +checksum = "4dd8df5ef180f6364759a6f00f7aadda4fbbac86cdee37480826a6ff9f3574ce" dependencies = [ "async-trait", "byteorder", @@ -6442,7 +6197,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.9.2", + "rand 0.10.1", "socket2", "tokio", "tokio-util", @@ -6451,9 +6206,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.3" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f63835928ca123f1bef57abbcd23bb2ba0ac9ae1235f1e65bda0d06e7786bd" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ "rustls", "tokio", @@ -6461,20 +6216,21 @@ 
dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] name = "tokio-util" -version = "0.7.16" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -6485,20 +6241,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.2" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f1085dec27c2b6632b04c80b3bb1b4300d6495d1e129693bdda7d91e72eec1" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.23.6" +version = "0.25.11+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3effe7c0e86fdff4f69cdd2ccc1b96f933e24811c5441d44904e8683e27184b" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.14.0", "toml_datetime", "toml_parser", "winnow", @@ -6506,25 +6262,25 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.3" +version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cf893c33be71572e0e9aa6dd15e6677937abd686b066eac3f8cd3531688a627" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ "winnow", ] [[package]] name = "tonic" -version = "0.14.2" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", "base64 0.22.1", "bytes", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper", @@ -6544,9 +6300,9 @@ dependencies = [ [[package]] name = "tonic-prost" -version = "0.14.2" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" dependencies = [ "bytes", "prost", @@ -6555,13 +6311,13 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", - "indexmap 2.12.0", + "indexmap 2.14.0", "pin-project-lite", "slab", "sync_wrapper", @@ -6574,14 +6330,14 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "bitflags 2.9.4", + "bitflags", 
"bytes", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "iri-string", "pin-project-lite", @@ -6604,9 +6360,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -6615,20 +6371,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", @@ -6647,9 +6403,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -6665,44 +6421,17 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "tstr" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f8e0294f14baae476d0dd0a2d780b2e24d66e349a9de876f5126777a37bdba7" -dependencies = [ - "tstr_proc_macros", -] - -[[package]] -name = "tstr_proc_macros" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e78122066b0cb818b8afd08f7ed22f7fdbc3e90815035726f0840d0d26c0747a" - [[package]] name = "twox-hash" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" -[[package]] -name = "typed-arena" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" - [[package]] name = "typenum" -version = "1.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" - -[[package]] -name = "typewit" -version = "1.14.2" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "typify" @@ -6720,7 +6449,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1eb359f7ffa4f9ebe947fa11a1b2da054564502968db5f317b7e37693cb2240" dependencies = [ - "heck 0.5.0", + "heck", "log", "proc-macro2", "quote", @@ -6729,7 +6458,7 @@ dependencies = [ 
"semver", "serde", "serde_json", - "syn 2.0.108", + "syn 2.0.117", "thiserror", "unicode-ident", ] @@ -6747,20 +6476,10 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.108", + "syn 2.0.117", "typify-impl", ] -[[package]] -name = "ulid" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" -dependencies = [ - "rand 0.9.2", - "web-time", -] - [[package]] name = "unicode-bidi" version = "0.3.18" @@ -6769,30 +6488,30 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.19" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-normalization" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" dependencies = [ "tinyvec", ] [[package]] name = "unicode-properties" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" [[package]] name = "unicode-width" @@ -6802,21 +6521,21 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode-width" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" [[package]] -name = "unindent" -version = "0.2.4" +name = "unicode-xid" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "unit-prefix" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "323402cff2dd658f39ca17c789b502021b3f18707c91cdf22e3838e1b4023817" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" [[package]] name = "unsafe-libyaml" @@ -6832,43 +6551,42 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "3.1.2" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99ba1025f18a4a3fc3e9b48c868e9beb4f24f4b4b1a325bada26bd4119f46537" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" dependencies = [ "base64 0.22.1", "log", "percent-encoding", "rustls", - "rustls-pemfile", "rustls-pki-types", "ureq-proto", - "utf-8", - "webpki-roots", + "utf8-zero", ] [[package]] name = "ureq-proto" -version = 
"0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b4531c118335662134346048ddb0e54cc86bd7e81866757873055f0e38f5d2" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" dependencies = [ "base64 0.22.1", - "http 1.3.1", + "http 1.4.0", "httparse", "log", ] [[package]] name = "url" -version = "2.5.7" +version = "2.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] @@ -6878,10 +6596,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] -name = "utf-8" -version = "0.7.6" +name = "utf8-zero" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" [[package]] name = "utf8_iter" @@ -6897,13 +6615,12 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.18.1" +version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ - "getrandom 0.3.4", + "getrandom 0.4.2", "js-sys", - "serde", "wasm-bindgen", ] @@ -6950,26 +6667,47 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + [[package]] name = "wasip2" -version = "1.0.1+wasi-0.2.4" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.51.0", ] [[package]] name = "wasite" -version = "0.1.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" +dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" dependencies = [ "cfg-if", "once_cell", @@ -6980,22 +6718,19 @@ dependencies = [ [[package]] name = 
"wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" dependencies = [ - "cfg-if", "js-sys", - "once_cell", "wasm-bindgen", - "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7003,48 +6738,85 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.55" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfc379bfb624eb59050b509c13e77b4eb53150c350db69628141abce842f2373" +checksum = "6bb55e2540ad1c56eec35fd63e2aea15f83b11ce487fd2de9ad11578dfc047ea" dependencies = [ + "async-trait", + "cast", "js-sys", + "libm", "minicov", + "nu-ansi-term", + "num-traits", + "oorandom", + "serde", + "serde_json", "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", ] [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.55" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "085b2df989e1e6f9620c1311df6c996e83fe16f57792b272ce1e024ac16a90f1" +checksum = "caf0ca1bd612b988616bac1ab34c4e4290ef18f7148a1d8b7f31c150080e9295" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.118" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23cda5ecc67248c48d3e705d3e03e00af905769b78b9d2a1678b663b8b9d4472" + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap 2.14.0", + "wasm-encoder", + "wasmparser", ] [[package]] @@ -7060,11 +6832,23 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 
0.15.5", + "indexmap 2.14.0", + "semver", +] + [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" dependencies = [ "js-sys", "wasm-bindgen", @@ -7080,22 +6864,15 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-roots" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b130c0d2d49f8b6889abc456e795e82525204f27c42cf767cf0d7734e089b8" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "whoami" -version = "1.6.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +checksum = "d6a5b12f9df4f978d2cfdb1bd3bac52433f44393342d7ee9c25f5a1c14c0f45d" dependencies = [ + "libc", "libredox", + "objc2-system-configuration", "wasite", "web-sys", ] @@ -7122,7 +6899,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.0", + "windows-sys 0.52.0", ] [[package]] @@ -7133,141 +6910,103 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.61.3" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" dependencies = [ "windows-collections", - "windows-core 0.61.2", + "windows-core", "windows-future", - "windows-link 0.1.3", "windows-numerics", ] [[package]] name = "windows-collections" -version = "0.2.0" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" dependencies = [ - "windows-core 0.61.2", -] - -[[package]] -name = "windows-core" -version = "0.61.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" -dependencies = [ - "windows-implement", - "windows-interface", - "windows-link 0.1.3", - "windows-result 0.3.4", - "windows-strings 0.4.2", + "windows-core", ] [[package]] name = "windows-core" -version = "0.62.0" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57fe7168f7de578d2d8a05b07fd61870d2e73b4020e9f49aa00da8471723497c" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.2.0", - "windows-result 0.4.0", - "windows-strings 0.5.0", + "windows-link", + "windows-result", + "windows-strings", ] [[package]] name = "windows-future" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" dependencies = [ - "windows-core 0.61.2", - "windows-link 0.1.3", + "windows-core", + "windows-link", "windows-threading", ] [[package]] name = "windows-implement" -version = 
"0.60.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "windows-interface" -version = "0.59.1" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "windows-link" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" - -[[package]] -name = "windows-link" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-numerics" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" -dependencies = [ - "windows-core 0.61.2", - "windows-link 0.1.3", -] - -[[package]] -name = "windows-result" -version = "0.3.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" dependencies = [ - "windows-link 0.1.3", + "windows-core", + "windows-link", ] [[package]] name = "windows-result" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7084dcc306f89883455a206237404d3eaf961e5bd7e0f312f7c91f57eb44167f" -dependencies = [ - "windows-link 0.2.0", -] - -[[package]] -name = "windows-strings" -version = "0.4.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7218c655a553b0bed4426cf54b20d7ba363ef543b52d515b3e48d7fd55318dda" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.2.0", + "windows-link", ] [[package]] @@ -7279,31 +7018,22 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-sys" -version = "0.59.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" -dependencies = [ - "windows-targets 0.52.6", -] - [[package]] name = "windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets 0.53.3", + "windows-targets 0.53.5", ] [[package]] name = "windows-sys" -version = "0.61.0" +version = "0.61.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e201184e40b2ede64bc2ea34968b28e33622acdbbf37104f0e4a33f7abe657aa" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.0", + "windows-link", ] [[package]] @@ -7324,28 +7054,28 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.53.3" +version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.1.3", - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] name = "windows-threading" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" dependencies = [ - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -7356,9 +7086,9 @@ checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" [[package]] name = "windows_aarch64_msvc" @@ -7368,9 +7098,9 @@ checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_aarch64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" [[package]] name = "windows_i686_gnu" @@ -7380,9 +7110,9 @@ checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" [[package]] name = "windows_i686_gnullvm" @@ -7392,9 +7122,9 @@ checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" [[package]] name = "windows_i686_msvc" @@ -7404,9 +7134,9 @@ checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_i686_msvc" -version = "0.53.0" +version = "0.53.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" [[package]] name = "windows_x86_64_gnu" @@ -7416,9 +7146,9 @@ checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" [[package]] name = "windows_x86_64_gnullvm" @@ -7428,9 +7158,9 @@ checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" [[package]] name = "windows_x86_64_msvc" @@ -7440,40 +7170,119 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "windows_x86_64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.13" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" dependencies = [ "memchr", ] [[package]] name = "wit-bindgen" -version = "0.46.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] [[package]] -name = "writeable" -version = "0.6.1" +name = "wit-bindgen" +version = "0.57.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" [[package]] -name = "wyz" -version = "0.5.1" +name = "wit-bindgen-core" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" dependencies = [ - "tap", + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap 2.14.0", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + 
"proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap 2.14.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 2.14.0", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", ] +[[package]] +name = "writeable" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + [[package]] name = "xattr" version = "1.6.1" @@ -7490,15 +7299,6 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" -[[package]] -name = "xz2" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] - [[package]] name = "yansi" version = "1.0.1" @@ -7507,11 +7307,10 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.8.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" dependencies = [ - "serde", "stable_deref_trait", "yoke-derive", "zerofrom", @@ -7519,68 +7318,68 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.8.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", "synstructure", ] [[package]] name = "zerocopy" -version = "0.8.27" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.27" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "zerofrom" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.6" +version = 
"0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", "synstructure", ] [[package]] name = "zeroize" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" [[package]] name = "zerotrie" -version = "0.2.2" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" dependencies = [ "displaydoc", "yoke", @@ -7589,9 +7388,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.4" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ "yoke", "zerofrom", @@ -7600,20 +7399,26 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.11.1" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.117", ] [[package]] name = "zlib-rs" -version = "0.5.2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be3d40e40a133f9c916ee3f9f4fa2d9d63435b5fbe1bfc6d9dae0aa0ada1513" + +[[package]] +name = "zmij" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" [[package]] name = "zstd" diff --git a/Cargo.toml b/Cargo.toml index f15929b4c2b00..37734211266ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ resolver = "2" [workspace.package] authors = ["Apache DataFusion "] -edition = "2021" +edition = "2024" homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" @@ -79,7 +79,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.88.0" # Define DataFusion version -version = "50.3.0" +version = "53.1.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -87,88 +87,103 @@ version = "50.3.0" # for the inherited dependency but cannot do the reverse (override from true to false). 
# # See for more details: https://github.com/rust-lang/cargo/issues/11329 -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } -apache-avro = { version = "0.20", default-features = false } -arrow = { version = "57.0.0", features = [ +apache-avro = { version = "0.21", default-features = false } +arrow = { version = "58.1.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "57.0.0", default-features = false } -arrow-flight = { version = "57.0.0", features = [ +arrow-avro = { version = "58.1.0", default-features = false, features = [ + "deflate", + "snappy", + "zstd", + "bzip2", + "xz", +] } +arrow-buffer = { version = "58.1.0", default-features = false } +arrow-data = { version = "58.1.0", default-features = false } +arrow-flight = { version = "58.1.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "57.0.0", default-features = false, features = [ +# Both codecs are required here to make sure that code paths like +# file-spilling have access to all compression codecs. +arrow-ipc = { version = "58.1.0", default-features = false, features = [ "lz4", + "zstd", ] } -arrow-ord = { version = "57.0.0", default-features = false } -arrow-schema = { version = "57.0.0", default-features = false } +arrow-ord = { version = "58.1.0", default-features = false } +arrow-schema = { version = "58.1.0", default-features = false } async-trait = "0.1.89" bigdecimal = "0.4.8" -bytes = "1.10" -chrono = { version = "0.4.42", default-features = false } -criterion = "0.7" -ctor = "0.6.1" +bytes = "1.11" +bzip2 = "0.6.1" +chrono = { version = "0.4.44", default-features = false } +criterion = "0.8" +ctor = "0.10.0" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "50.3.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "50.3.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "50.3.0" } -datafusion-common = { path = "datafusion/common", version = "50.3.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "50.3.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "50.3.0", default-features = false } -datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "50.3.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "50.3.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "50.3.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "50.3.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "50.3.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "50.3.0" } -datafusion-execution = { path = "datafusion/execution", version = "50.3.0", default-features = false } -datafusion-expr = { path = "datafusion/expr", version = "50.3.0", default-features = false } -datafusion-expr-common = { path = "datafusion/expr-common", version = "50.3.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "50.3.0" } -datafusion-functions = { path = "datafusion/functions", version = "50.3.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "50.3.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "50.3.0" } 
-datafusion-functions-nested = { path = "datafusion/functions-nested", version = "50.3.0", default-features = false } -datafusion-functions-table = { path = "datafusion/functions-table", version = "50.3.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "50.3.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "50.3.0" } -datafusion-macros = { path = "datafusion/macros", version = "50.3.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "50.3.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "50.3.0", default-features = false } -datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "50.3.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "50.3.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "50.3.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "50.3.0" } -datafusion-proto = { path = "datafusion/proto", version = "50.3.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "50.3.0" } -datafusion-pruning = { path = "datafusion/pruning", version = "50.3.0" } -datafusion-session = { path = "datafusion/session", version = "50.3.0" } -datafusion-spark = { path = "datafusion/spark", version = "50.3.0" } -datafusion-sql = { path = "datafusion/sql", version = "50.3.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "50.3.0" } +datafusion = { path = "datafusion/core", version = "53.1.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "53.1.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "53.1.0" } +datafusion-common = { path = "datafusion/common", version = "53.1.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "53.1.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "53.1.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "53.1.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "53.1.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "53.1.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "53.1.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "53.1.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "53.1.0" } +datafusion-execution = { path = "datafusion/execution", version = "53.1.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "53.1.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "53.1.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "53.1.0" } +datafusion-functions = { path = "datafusion/functions", version = "53.1.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "53.1.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "53.1.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", 
version = "53.1.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "53.1.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "53.1.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "53.1.0" } +datafusion-macros = { path = "datafusion/macros", version = "53.1.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "53.1.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "53.1.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "53.1.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "53.1.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "53.1.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "53.1.0" } +datafusion-proto = { path = "datafusion/proto", version = "53.1.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "53.1.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "53.1.0" } +datafusion-session = { path = "datafusion/session", version = "53.1.0" } +datafusion-spark = { path = "datafusion/spark", version = "53.1.0" } +datafusion-sql = { path = "datafusion/sql", version = "53.1.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "53.1.0" } doc-comment = "0.3" env_logger = "0.11" +flate2 = "1.1.9" futures = "0.3" +glob = "0.3.0" half = { version = "2.7.0", default-features = false } -hashbrown = { version = "0.14.5", features = ["raw"] } +hashbrown = { version = "0.17.0" } hex = { version = "0.4.3" } -indexmap = "2.12.0" -insta = { version = "1.43.2", features = ["glob", "filters"] } +indexmap = "2.14.0" +insta = { version = "1.47.2", features = ["glob", "filters"] } itertools = "0.14" +itoa = "1.0" +liblzma = { version = "0.4.6", features = ["static"] } log = "^0.4" +memchr = "2.8.0" num-traits = { version = "0.2" } -object_store = { version = "0.12.4", default-features = false } +object_store = { version = "0.13.2", default-features = false } parking_lot = "0.12" -parquet = { version = "57.0.0", default-features = false, features = [ +parquet = { version = "58.1.0", default-features = false, features = [ "arrow", "async", "object_store", ] } -pbjson = { version = "0.8.0" } -pbjson-types = "0.8" +pbjson = { version = "0.9.0" } +pbjson-types = "0.9" +pin-project = "1" # Should match arrow-flight's version of prost. 
prost = "0.14.1" rand = "0.9" @@ -176,12 +191,18 @@ recursive = "0.1.1" regex = "1.12" rstest = "0.26.1" serde_json = "1" -sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } +sha2 = "^0.11.0" +sqlparser = { version = "0.61.0", default-features = false, features = ["std", "visitor"] } +strum = "0.28.0" +strum_macros = "0.28.0" tempfile = "3" -testcontainers = { version = "0.25.2", features = ["default"] } -testcontainers-modules = { version = "0.13" } -tokio = { version = "1.48", features = ["macros", "rt", "sync"] } +testcontainers-modules = { version = "0.15" } +tokio = { version = "1.52", features = ["macros", "rt", "sync"] } +tokio-stream = "0.1" +tokio-util = "0.7" url = "2.5.7" +uuid = "1.23" +zstd = { version = "0.13", default-features = false } [workspace.lints.clippy] # Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) @@ -191,6 +212,11 @@ or_fun_call = "warn" unnecessary_lazy_evaluations = "warn" uninlined_format_args = "warn" inefficient_to_string = "warn" +# https://github.com/apache/datafusion/issues/18503 +needless_pass_by_value = "warn" +# https://github.com/apache/datafusion/issues/18881 +allow_attributes = "warn" +assigning_clones = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = [ @@ -203,50 +229,56 @@ unused_qualifications = "deny" # -------------------- # Compilation Profiles # -------------------- -# A Cargo profile is a preset for the compiler/linker knobs that trade off: +# A Cargo profile is a preset for the compiler/linker knobs that trades off: # - Build time: how quickly code compiles and links # - Runtime performance: how fast the resulting binaries execute # - Binary size: how large the executables end up # - Debuggability: how much debug information is preserved for debugging and profiling # +# To use a profile: `cargo [ build | run | ... ] --profile ` +# # Profiles available: -# - dev: default debug build; fastest to compile, slowest to run, full debug info -# for everyday development. -# Run: cargo run -# - release: optimized build; slowest to compile, fastest to run, smallest -# binaries for public releases. -# Run: cargo run --release -# - release-nonlto: skips LTO, so it builds quicker while staying close to -# release performance. It is useful when developing performance optimizations. -# Run: cargo run --profile release-nonlto +# - dev: default debug build; fastest to compile, slowest to run, full debug info. +# For everyday development; default for "cargo [ build | test | run ]". +# - release: fully optimized build; slowest to compile, fastest to run, smallest +# binaries. For public releases; default for "cargo [ bench | install ]". +# - release-nonlto: skips LTO, so it builds much faster while staying close to +# release performance. Useful when developing performance optimizations. # - profiling: inherits release optimizations but retains debug info to support # profiling tools and flamegraphs. -# Run: cargo run --profile profiling -# - ci: derived from `dev` but disables incremental builds and strips dependency -# symbols to keep CI artifacts small and reproducible. -# Run: cargo run --profile ci +# - ci: derived from `dev` but disables debug info and incremental builds to keep +# CI artifacts small and reproducible. +# - ci-optimized: derived from `release` but enables debug assertions and uses +# less aggressive optimizations for faster builds. Used for long-running CI +# tasks. 
# # If you want to optimize compilation, the `compile_profile` benchmark can be useful. # See `benchmarks/README.md` for more details. [profile.release] codegen-units = 1 lto = true -strip = true # Eliminate debug information to minimize binary size +strip = true # Eliminate debug info to minimize binary size [profile.release-nonlto] -codegen-units = 16 -debug-assertions = false -incremental = false inherits = "release" +codegen-units = 16 lto = false -opt-level = 3 -overflow-checks = false -rpath = false -strip = false # Retain debug info for flamegraphs +strip = false # Retain debug info for flamegraphs + +[profile.profiling] +inherits = "release" +debug = true +strip = false + +[profile.ci-optimized] +inherits = "release" +debug-assertions = true +codegen-units = 16 +lto = "thin" [profile.ci] -debug = false inherits = "dev" +debug = false incremental = false # This rule applies to every package except workspace members (dependencies @@ -257,8 +289,3 @@ debug = false debug-assertions = false strip = "debuginfo" incremental = false - -[profile.profiling] -inherits = "release" -debug = true -strip = false diff --git a/NOTICE.txt b/NOTICE.txt index 7f3c80d606c07..0bd2d52368fea 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2025 The Apache Software Foundation +Copyright 2019-2026 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md index 5191496eaafe3..630d4295bd427 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ [![Build Status][actions-badge]][actions-url] ![Commit Activity][commit-activity-badge] [![Open Issues][open-issues-badge]][open-issues-url] +[![Pending PRs][pending-pr-badge]][pending-pr-url] [![Discord chat][discord-badge]][discord-url] [![Linkedin][linkedin-badge]][linkedin-url] ![Crates.io MSRV][msrv-badge] @@ -39,6 +40,8 @@ [commit-activity-badge]: https://img.shields.io/github/commit-activity/m/apache/datafusion [open-issues-badge]: https://img.shields.io/github/issues-raw/apache/datafusion [open-issues-url]: https://github.com/apache/datafusion/issues +[pending-pr-badge]: https://img.shields.io/github/issues-search/apache/datafusion?query=is%3Apr+is%3Aopen+draft%3Afalse+review%3Arequired+status%3Asuccess&label=Pending%20PRs&logo=github +[pending-pr-url]: https://github.com/apache/datafusion/pulls?q=is%3Apr+is%3Aopen+draft%3Afalse+review%3Arequired+status%3Asuccess+sort%3Aupdated-desc [linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue [linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ [msrv-badge]: https://img.shields.io/crates/msrv/datafusion?label=Min%20Rust%20Version @@ -55,7 +58,7 @@ DataFusion is an extensible query engine written in [Rust] that uses [Apache Arrow] as its in-memory format. This crate provides libraries and binaries for developers building fast and -feature rich database and analytic systems, customized to particular workloads. +feature-rich database and analytic systems, customized for particular workloads. See [use cases] for examples. The following related subprojects target end users: - [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame @@ -64,7 +67,7 @@ See [use cases] for examples. The following related subprojects target end users DataFusion. 
"Out of the box," -DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [Dataframe](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], +DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. @@ -81,7 +84,7 @@ See the [Architecture] section for more details. [performance]: https://benchmark.clickhouse.com/ [architecture]: https://datafusion.apache.org/contributor-guide/architecture.html -Here are links to some important information +Here are links to important resources: - [Project Site](https://datafusion.apache.org/) - [Installation](https://datafusion.apache.org/user-guide/cli/installation.html) @@ -94,8 +97,8 @@ Here are links to some important information ## What can you do with this crate? -DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. -It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users. +DataFusion is great for building projects such as domain-specific query engines, new database platforms and data pipelines, query languages and more. +It lets you start quickly from a fully working engine, and then customize those features specific to your needs. See the [list of known users](https://datafusion.apache.org/user-guide/introduction.html#known-users). ## Contributing to DataFusion @@ -112,15 +115,15 @@ This crate has several [features] which can be specified in your `Cargo.toml`. Default features: -- `nested_expressions`: functions for working with nested type function such as `array_to_string` +- `nested_expressions`: functions for working with nested types such as `array_to_string` - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` - `crypto_expressions`: cryptographic functions such as `md5` and `sha256` - `datetime_expressions`: date and time functions such as `to_timestamp` - `encoding_expressions`: `encode` and `decode` functions - `parquet`: support for reading the [Apache Parquet] format -- `sql`: Support for sql parsing / planning +- `sql`: support for SQL parsing and planning - `regex_expressions`: regular expression functions, such as `regexp_match` -- `unicode_expressions`: Include unicode aware functions such as `character_length` +- `unicode_expressions`: include Unicode-aware functions such as `character_length` - `unparser`: enables support to reverse LogicalPlans back into SQL - `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. 
@@ -129,7 +132,6 @@ Optional features: - `avro`: support for reading the [Apache Avro] format - `backtrace`: include backtrace information in error messages - `parquet_encryption`: support for using [Parquet Modular Encryption] -- `pyarrow`: conversions between PyArrow and DataFusion types - `serde`: enable arrow-schema's `serde` feature [apache avro]: https://avro.apache.org/ diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore index c35b1a7c1944f..1e59d094eb063 100644 --- a/benchmarks/.gitignore +++ b/benchmarks/.gitignore @@ -1,3 +1,5 @@ data +data_csv results venv +!sql_benchmarks/**/results/ diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index b3fd520814dbc..1815f8bc42ca3 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -25,7 +25,11 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } rust-version = { workspace = true } +publish = false +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -37,6 +41,10 @@ mimalloc_extended = ["libmimalloc-sys/extended"] [dependencies] arrow = { workspace = true } +async-trait = "0.1" +bytes = { workspace = true } +clap = { version = "4.6.0", features = ["derive", "env"] } +criterion = { workspace = true, features = ["html_reports"] } datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } @@ -50,10 +58,14 @@ rand = { workspace = true } regex.workspace = true serde = { version = "1.0.228", features = ["derive"] } serde_json = { workspace = true } -snmalloc-rs = { version = "0.3", optional = true } -structopt = { version = "0.3", default-features = false } +snmalloc-rs = { version = "0.7", optional = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tokio-util = { version = "0.7.16" } +tokio-util = { version = "0.7.17" } [dev-dependencies] datafusion-proto = { workspace = true } +tempfile = { workspace = true } + +[[bench]] +harness = false +name = "sql" diff --git a/benchmarks/README.md b/benchmarks/README.md index 8fed85fa02b80..a4ddb09e0771c 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -95,7 +95,7 @@ Generate the data required for the compile profile helper (TPC-H SF=1): ./bench.sh data compile_profile ``` -Run the benchmark across all default Cargo profiles (`dev`, `release`, `ci`, `release-nonlto`): +Run the benchmark across all default Cargo profiles (`dev`, `release`, `ci`, `ci-optimized`, `release-nonlto`, `profiling`): ```shell ./bench.sh run compile_profile @@ -119,7 +119,6 @@ You can also invoke the helper directly if you need to customise arguments furth ./benchmarks/compile_profile.py --profiles dev release --data /path/to/tpch_sf1 ``` - ## Benchmark with modified configurations ### Select join algorithm @@ -147,6 +146,19 @@ To verify that datafusion picked up your configuration, run the benchmarks with ## Comparing performance of main and a branch +For TPC-H +```shell +./benchmarks/compare_tpch.sh main mybranch +``` + +For TPC-DS. 
+To get data in `DATA_DIR` for TPCDS, first run `./benchmarks/bench.sh data tpcds`
+```shell
+DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/compare_tpcds.sh main mybranch
+```
+
+Alternatively, you can compare manually by following the example below
+
```shell
git checkout main
@@ -228,6 +240,23 @@ Benchmark tpch_mem.json
└──────────────┴──────────────┴──────────────┴───────────────┘
```

+## Comparing performance of main and a PR
+
+### TPCDS
+
+Assuming you already have the TPCDS data locally:
+
+```shell
+export DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/
+export PR_NUMBER=19464
+git fetch upstream pull/$PR_NUMBER/head:pr-$PR_NUMBER
+git checkout main
+git pull
+./benchmarks/compare_tpcds.sh main pr-$PR_NUMBER
+```
+
+Note: if `gh` is installed, you can also run `gh pr checkout $PR_NUMBER` instead of `git fetch upstream pull/$PR_NUMBER/head:pr-$PR_NUMBER`.
+
### Running Benchmarks Manually

Assuming data is in the `data` directory, the `tpch` benchmark can be run with a command like this:
@@ -243,28 +272,11 @@ See the help for more details.

You can enable `mimalloc` or `snmalloc` (to use either the mimalloc or snmalloc allocator) as features by passing them in as `--features`. For example:

```shell
-cargo run --release --features "mimalloc" --bin tpch -- benchmark datafusion --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096
-```
-
-The benchmark program also supports CSV and Parquet input file formats and a utility is provided to convert from `tbl`
-(generated by the `dbgen` utility) to CSV and Parquet.
-
-```bash
-cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-parquet --format parquet
+cargo run --release --features "mimalloc" --bin dfbench tpch --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096
```

Or if you want to verify and run all the queries in the benchmark, you can just run `cargo test`.

-#### Sorted Conversion
-
-The TPCH tables generated by the dbgen utility are sorted by their first column (their primary key for most tables, the `l_orderkey` column for the `lineitem` table.)
-
-To preserve this sorted order information during conversion (useful for benchmarking execution on pre-sorted data) include the `--sort` flag:
-
-```bash
-cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-sorted-parquet --format parquet --sort
-```
-
### Comparing results between runs

Any `dfbench` execution with `-o <dir>` argument will produce a
@@ -316,7 +328,6 @@ This will produce output like:
└──────────────┴──────────────┴──────────────┴───────────────┘
```

-
# Benchmark Runner

The `dfbench` program contains subcommands to run the various
@@ -356,24 +367,28 @@ FLAGS:
```

# Profiling Memory Stats for each benchmark query
+
The `mem_profile` program wraps benchmark execution to measure memory usage statistics, such as peak RSS. It runs each benchmark query in a separate subprocess, capturing the child process’s stdout to print structured output.

Subcommands supported by mem_profile are a subset of those in `dfbench`.
Currently supported benchmarks include: Clickbench, H2o, Imdb, SortTpch, Tpch, TPCDS

Before running benchmarks, `mem_profile` automatically compiles the benchmark binary (`dfbench`) using `cargo build`. Note that the build profile used for `dfbench` is not tied to the profile used for running `mem_profile` itself.
We can explicitly specify the desired build profile using the `--bench-profile` option (e.g. release-nonlto). By prebuilding the binary and running each query in a separate process, we can ensure accurate memory statistics.

Currently, `mem_profile` only supports `mimalloc` as the memory allocator, since it relies on `mimalloc`'s API to collect memory statistics.

-Because it runs the compiled binary directly from the target directory, make sure your working directory is the top-level datafusion/ directory, where the target/ is also located.
+Because it runs the compiled binary directly from the target directory, make sure your working directory is the top-level datafusion/ directory, where the target/ is also located.
+
+The benchmark subcommand (e.g., `tpch`) and all following arguments are passed directly to `dfbench`. Be sure to specify `--bench-profile` before the benchmark subcommand.

-The benchmark subcommand (e.g., `tpch`) and all following arguments are passed directly to `dfbench`. Be sure to specify `--bench-profile` before the benchmark subcommand.
+Example:

-Example:
```shell
datafusion$ cargo run --profile release-nonlto --bin mem_profile -- --bench-profile release-nonlto tpch --path benchmarks/data/tpch_sf1 --partitions 4 --format parquet
```
+
Example Output:
+
```
Query Time (ms) Peak RSS Peak Commit Major Page Faults
----------------------------------------------------------------
@@ -402,19 +417,21 @@ Query Time (ms) Peak RSS Peak Commit Major Page Faults
```

## Reported Metrics
+
When running benchmarks, `mem_profile` collects several memory-related statistics using the mimalloc API:

-- Peak RSS (Resident Set Size):
-The maximum amount of physical memory used by the process.
-This is a process-level metric collected via OS-specific mechanisms and is not mimalloc-specific.
+- Peak RSS (Resident Set Size):
+  The maximum amount of physical memory used by the process.
+  This is a process-level metric collected via OS-specific mechanisms and is not mimalloc-specific.

- Peak Commit:
-The peak amount of memory committed by the allocator (i.e., total virtual memory reserved).
-This is mimalloc-specific. It gives a more allocator-aware view of memory usage than RSS.
+  The peak amount of memory committed by the allocator (i.e., total virtual memory reserved).
+  This is mimalloc-specific. It gives a more allocator-aware view of memory usage than RSS.

- Major Page Faults:
-The number of major page faults triggered during execution.
-This metric is obtained from the operating system and is not mimalloc-specific.
+  The number of major page faults triggered during execution.
+  This metric is obtained from the operating system and is not mimalloc-specific.
+

# Writing a new benchmark

## Creating or downloading data outside of the benchmark

@@ -603,6 +620,34 @@ This benchmark is derived from the [TPC-H][1] version

[2]: https://github.com/databricks/tpch-dbgen.git,
[2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf

+## TPCDS
+
+Run the tpcds benchmark.
+
+For data, please clone the `datafusion-benchmarks` repo, which contains predefined parquet data at SF1.
+
+```shell
+git clone https://github.com/apache/datafusion-benchmarks
+```
+
+Then run the benchmark with the following command:
+
+```shell
+DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/bench.sh run tpcds
+```
+
+Alternatively, benchmark a specific query:
+
+```shell
+DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/bench.sh run tpcds 30
+```
+
+For more help:
+
+```shell
+cargo run --release --bin dfbench -- tpcds --help
+```
+
## External Aggregation

Run the benchmark for aggregations with limited memory.

@@ -762,7 +807,7 @@ Different queries are included to test nested loop joins under various workloads

## Hash Join

-This benchmark focuses on the performance of queries with nested hash joins, minimizing other overheads such as scanning data sources or evaluating predicates.
+This benchmark focuses on the performance of queries with hash joins, minimizing other overheads such as scanning data sources or evaluating predicates.

Several queries are included to test hash joins under various workloads.

@@ -774,6 +819,19 @@ Several queries are included to test hash joins under various workloads.

./bench.sh run hj
```
+## Sort Merge Join
+
+This benchmark focuses on the performance of queries with sort merge joins, minimizing other overheads such as scanning data sources or evaluating predicates.
+
+Several queries are included to test sort merge joins under various workloads.
+
+### Example Run
+
+```bash
+# No need to generate data: this benchmark uses table function `range()` as the data source
+
+./bench.sh run smj
+```
## Cancellation

Test performance of cancelling queries.

@@ -804,3 +862,82 @@ Getting results...
cancelling thread
done dropping runtime in 83.531417ms
```
+
+## Sorted Data Benchmarks
+
+### Data Sorted ClickBench
+
+Benchmark for queries on pre-sorted data to test sort order optimization.
+This benchmark uses a subset of the ClickBench dataset (hits.parquet, ~14GB) that has been pre-sorted by the EventTime column. The queries are designed to test DataFusion's performance when the data is already sorted, as is common in time-series workloads.
+
+The benchmark includes queries that:
+- Scan pre-sorted data with ORDER BY clauses that match the sort order
+- Test reverse scans on sorted data
+- Verify the performance results
+
+#### Generating Sorted Data
+
+The sorted dataset is automatically generated from the ClickBench partitioned dataset. You can configure the memory used during the sorting process with the `DATAFUSION_MEMORY_GB` environment variable. The default memory limit is 12GB.
+```bash
+./bench.sh data clickbench_sorted
+```
+
+To create the sorted dataset, for example with 16GB of memory, run:
+
+```bash
+DATAFUSION_MEMORY_GB=16 ./bench.sh data clickbench_sorted
+```
+
+This command will:
+1. Download the ClickBench partitioned dataset if not present
+2. Sort hits.parquet by EventTime in ascending order
+3. Save the sorted file as hits_sorted.parquet
+
+#### Running the Benchmark
+
+```bash
+./bench.sh run clickbench_sorted
+```
+
+This runs queries against the pre-sorted dataset with the `--sorted-by EventTime` flag, which informs DataFusion that the data is pre-sorted, allowing it to optimize away redundant sort operations.
+
+## Sort Pushdown
+
+Benchmarks for sort pushdown optimizations on TPC-H lineitem data (SF=1).
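+
+To make the tested query shape concrete, here is an illustrative sketch (a hypothetical query against the TPC-H `lineitem` table, not necessarily one of the benchmark's own q1-q8):
+
+```sql
+-- When the files are known to be sorted by l_shipdate, a DESC + LIMIT
+-- query like this can be served by a reverse scan plus TopK rather
+-- than a full sort of the table.
+SELECT l_orderkey, l_shipdate
+FROM lineitem
+ORDER BY l_shipdate DESC
+LIMIT 10;
+```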
+ +### Variants + +| Benchmark | Description | +|-----------|-------------| +| `sort_pushdown` | Baseline — no `WITH ORDER`, tests standard sort behavior | +| `sort_pushdown_sorted` | With `WITH ORDER` — tests sort elimination on sorted files | +| `sort_pushdown_inexact` | Inexact path (`--sorted` DESC) — multi-file with scrambled RGs, tests reverse scan + RG reorder | +| `sort_pushdown_inexact_unsorted` | No `WITH ORDER` — same data, tests Unsupported path + RG reorder | +| `sort_pushdown_inexact_overlap` | Multi-file scrambled RGs — streaming data scenario | + +### Queries + +**sort_pushdown / sort_pushdown_sorted** (q1-q8): +- q1-q4: ASC queries (sort elimination with `--sorted`) +- q5-q8: DESC LIMIT queries (reverse scan + TopK optimization with `--sorted`) + +**sort_pushdown_inexact** (q1-q4): DESC LIMIT queries on scrambled data + +### Data Generation + +The inexact/overlap data requires pyarrow (`pip install pyarrow`) to generate +multi-file parquet with scrambled row group order. DataFusion's COPY cannot produce +narrow-range RGs in scrambled order because the parquet writer merges rows from +adjacent chunks at RG boundaries. + +### Running + +```bash +# Generate data and run all sort pushdown benchmarks +./bench.sh data sort_pushdown +./bench.sh data sort_pushdown_inexact +./bench.sh run sort_pushdown +./bench.sh run sort_pushdown_sorted +./bench.sh run sort_pushdown_inexact +./bench.sh run sort_pushdown_inexact_overlap +``` diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index dbfd319dd9ad4..7aa0418e1d74d 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -41,8 +41,15 @@ BENCHMARK=all DATAFUSION_DIR=${DATAFUSION_DIR:-$SCRIPT_DIR/..} DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} +SQL_CARGO_COMMAND=${SQL_CARGO_COMMAND:-"cargo bench --bench sql"} PREFER_HASH_JOIN=${PREFER_HASH_JOIN:-true} -VIRTUAL_ENV=${VIRTUAL_ENV:-$SCRIPT_DIR/venv} +SIMULATE_LATENCY=${SIMULATE_LATENCY:-false} + +# Build latency arg based on SIMULATE_LATENCY setting +LATENCY_ARG="" +if [ "$SIMULATE_LATENCY" = "true" ]; then + LATENCY_ARG="--simulate-latency" +fi usage() { echo " @@ -53,7 +60,6 @@ $0 data [benchmark] $0 run [benchmark] [query] $0 compare $0 compare_detail -$0 venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Examples: @@ -71,7 +77,6 @@ data: Generates or downloads data needed for benchmarking run: Runs the named benchmark compare: Compares fastest results from benchmark runs compare_detail: Compares minimum, average (±stddev), and maximum results from benchmark runs -venv: Creates new venv (unless already exists) and installs compare's requirements into it ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Benchmarks @@ -87,6 +92,9 @@ tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), tpch_csv10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single csv file per table, hash join tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory +# TPC-DS Benchmarks +tpcds: TPCDS inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join + # Extended TPC-H Benchmarks sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=1) sort_tpch10: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=10) @@ -99,6 +107,16 @@ clickbench_partitioned: ClickBench queries against partitioned (100 files) parqu clickbench_pushdown: ClickBench queries against partitioned (100 files) parquet w/ filter_pushdown 
enabled clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) +# Sort Pushdown Benchmarks +sort_pushdown: Sort pushdown baseline (no WITH ORDER) on TPC-H data (SF=1) +sort_pushdown_sorted: Sort pushdown with WITH ORDER — tests sort elimination on non-overlapping files +sort_pushdown_inexact: Sort pushdown Inexact path (--sorted DESC) — multi-file with scrambled RGs, tests reverse scan + RG reorder +sort_pushdown_inexact_unsorted: Sort pushdown Inexact path (no WITH ORDER) — same data, tests Unsupported path + RG reorder +sort_pushdown_inexact_overlap: Sort pushdown Inexact path — multi-file scrambled RGs (streaming data scenario) + +# Sorted Data Benchmarks (ORDER BY Optimization) +clickbench_sorted: ClickBench queries on pre-sorted data using prefer_existing_sort (tests sort elimination optimization) + # H2O.ai Benchmarks (Group By, Join, Window) h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv @@ -126,6 +144,7 @@ imdb: Join Order Benchmark (JOB) using the IMDB dataset conver cancellation: How long cancelling a query takes nlj: Benchmark for simple nested loop joins, testing various join scenarios hj: Benchmark for simple hash joins, testing various join scenarios +smj: Benchmark for simple sort merge joins, testing various join scenarios compile_profile: Compile and execute TPC-H across selected Cargo profiles, reporting timing and binary size @@ -137,7 +156,7 @@ CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored PREFER_HASH_JOIN Prefer hash join algorithm (default true) -VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) +SIMULATE_LATENCY Simulate object store latency to mimic S3 (default false) DATAFUSION_* Set the given datafusion configuration " exit 1 @@ -189,8 +208,8 @@ main() { echo "***************************" case "$BENCHMARK" in all) - data_tpch "1" - data_tpch "10" + data_tpch "1" "parquet" + data_tpch "10" "parquet" data_h2o "SMALL" data_h2o "MEDIUM" data_h2o "BIG" @@ -203,18 +222,25 @@ main() { # nlj uses range() function, no data generation needed ;; tpch) - data_tpch "1" + data_tpch "1" "parquet" ;; tpch_mem) - # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" + ;; + tpch_csv) + data_tpch "1" "csv" ;; tpch10) - data_tpch "10" + data_tpch "10" "parquet" ;; tpch_mem10) - # same data as for tpch10 - data_tpch "10" + data_tpch "10" "parquet" + ;; + tpch_csv10) + data_tpch "10" "csv" + ;; + tpcds) + data_tpcds ;; clickbench_1) data_clickbench_1 @@ -289,30 +315,42 @@ main() { ;; external_aggr) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" + ;; + sort_pushdown|sort_pushdown_sorted) + data_sort_pushdown + ;; + sort_pushdown_inexact|sort_pushdown_inexact_unsorted|sort_pushdown_inexact_overlap) + data_sort_pushdown_inexact ;; sort_tpch) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" ;; sort_tpch10) # same data as for tpch10 - data_tpch "10" + data_tpch "10" "parquet" ;; topk_tpch) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" ;; nlj) # nlj uses range() function, no data generation needed echo "NLJ benchmark does not require data generation" ;; hj) - # hj uses range() function, no data generation needed - echo "HJ benchmark does not require data 
generation" + data_tpch "10" "parquet" + ;; + smj) + # smj uses range() function, no data generation needed + echo "SMJ benchmark does not require data generation" ;; compile_profile) - data_tpch "1" + data_tpch "1" "parquet" + ;; + clickbench_sorted) + clickbench_sorted ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" @@ -355,6 +393,7 @@ main() { echo "RESULTS_DIR: ${RESULTS_DIR}" echo "CARGO_COMMAND: ${CARGO_COMMAND}" echo "PREFER_HASH_JOIN: ${PREFER_HASH_JOIN}" + echo "SIMULATE_LATENCY: ${SIMULATE_LATENCY}" echo "***************************" # navigate to the appropriate directory @@ -384,6 +423,8 @@ main() { run_external_aggr run_nlj run_hj + run_tpcds + run_smj ;; tpch) run_tpch "1" "parquet" @@ -403,6 +444,9 @@ main() { tpch_mem10) run_tpch_mem "10" ;; + tpcds) + run_tpcds + ;; cancellation) run_cancellation ;; @@ -445,7 +489,7 @@ main() { h2o_medium_window) run_h2o_window "MEDIUM" "CSV" "window" ;; - h2o_big_window) + h2o_big_window) run_h2o_window "BIG" "CSV" "window" ;; h2o_small_parquet) @@ -479,6 +523,21 @@ main() { external_aggr) run_external_aggr ;; + sort_pushdown) + run_sort_pushdown + ;; + sort_pushdown_sorted) + run_sort_pushdown_sorted + ;; + sort_pushdown_inexact) + run_sort_pushdown_inexact + ;; + sort_pushdown_inexact_unsorted) + run_sort_pushdown_inexact_unsorted + ;; + sort_pushdown_inexact_overlap) + run_sort_pushdown_inexact_overlap + ;; sort_tpch) run_sort_tpch "1" ;; @@ -494,9 +553,15 @@ main() { hj) run_hj ;; + smj) + run_smj + ;; compile_profile) run_compile_profile "${PROFILE_ARGS[@]}" ;; + clickbench_sorted) + run_clickbench_sorted + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -511,9 +576,6 @@ main() { compare_detail) compare_benchmarks "$ARG2" "$ARG3" "--detailed" ;; - venv) - setup_venv - ;; "") usage ;; @@ -529,7 +591,7 @@ main() { # Creates TPCH data at a certain scale factor, if it doesn't already # exist # -# call like: data_tpch($scale_factor) +# call like: data_tpch($scale_factor, format) # # Creates data in $DATA_DIR/tpch_sf1 for scale factor 1 # Creates data in $DATA_DIR/tpch_sf10 for scale factor 10 @@ -540,20 +602,23 @@ data_tpch() { echo "Internal error: Scale factor not specified" exit 1 fi + FORMAT=$2 + if [ -z "$FORMAT" ] ; then + echo "Internal error: Format not specified" + exit 1 + fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - echo "Creating tpch dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." + echo "Creating tpch $FORMAT dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." # Ensure the target data directory exists mkdir -p "${TPCH_DIR}" - # Create 'tbl' (CSV format) data into $DATA_DIR if it does not already exist - FILE="${TPCH_DIR}/supplier.tbl" - if test -f "${FILE}"; then - echo " tbl files exist ($FILE exists)." - else - echo " creating tbl files with tpch_dbgen..." - docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s "${SCALE_FACTOR}" + # check if tpchgen-cli is installed + if ! 
command -v tpchgen-cli &> /dev/null + then + echo "tpchgen-cli could not be found, please install it via 'cargo install tpchgen-cli'" + exit 1 fi # Copy expected answers into the ./data/answers directory if it does not already exist @@ -566,27 +631,52 @@ data_tpch() { docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" fi - # Create 'parquet' files from tbl - FILE="${TPCH_DIR}/supplier" - if test -d "${FILE}"; then - echo " parquet files exist ($FILE exists)." - else - echo " creating parquet files using benchmark binary ..." - pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet - popd > /dev/null + if [ "$FORMAT" = "parquet" ]; then + # Create 'parquet' files, one directory per file + FILE="${TPCH_DIR}/supplier" + if test -d "${FILE}"; then + echo " parquet files exist ($FILE exists)." + else + echo " creating parquet files using tpchgen-cli ..." + tpchgen-cli --scale-factor "${SCALE_FACTOR}" --format parquet --parquet-compression='ZSTD(1)' --parts=1 --output-dir "${TPCH_DIR}" + fi + return fi - # Create 'csv' files from tbl - FILE="${TPCH_DIR}/csv/supplier" - if test -d "${FILE}"; then - echo " csv files exist ($FILE exists)." - else - echo " creating csv files using benchmark binary ..." - pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv - popd > /dev/null + # Create 'csv' files, one directory per file + if [ "$FORMAT" = "csv" ]; then + FILE="${TPCH_DIR}/csv/supplier" + if test -d "${FILE}"; then + echo " csv files exist ($FILE exists)." + else + echo " creating csv files using tpchgen-cli binary ..." + tpchgen-cli --scale-factor "${SCALE_FACTOR}" --format csv --parts=1 --output-dir "${TPCH_DIR}/csv" + fi + return + fi + + echo "Error: unknown format '$FORMAT' for tpch data generation, expected 'parquet' or 'csv'" + exit 1 +} + +# Downloads TPC-DS data +data_tpcds() { + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" + + # Check if `web_site.parquet` exists in the TPCDS data directory to verify data presence + echo "Checking TPC-DS data directory: ${TPCDS_DIR}" + if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then + mkdir -p "${TPCDS_DIR}" + # Download the DataFusion benchmarks repository zip if it is not already downloaded + if [ ! -f "${DATA_DIR}/datafusion-benchmarks.zip" ]; then + echo "Downloading DataFusion benchmarks repository zip to: ${DATA_DIR}/datafusion-benchmarks.zip" + wget --timeout=30 --tries=3 -O "${DATA_DIR}/datafusion-benchmarks.zip" https://github.com/apache/datafusion-benchmarks/archive/refs/heads/main.zip + fi + echo "Extracting TPC-DS parquet data to ${TPCDS_DIR}..." + unzip -o -j -d "${TPCDS_DIR}" "${DATA_DIR}/datafusion-benchmarks.zip" datafusion-benchmarks-main/tpcds/data/sf1/* + echo "TPC-DS data extracted." fi + echo "Done." } # Runs the tpch benchmark @@ -596,30 +686,54 @@ run_tpch() { echo "Internal error: Scale factor not specified" exit 1 fi - TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - - RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" - echo "RESULTS_FILE: ${RESULTS_FILE}" + FORMAT=$2 echo "Running tpch benchmark..." 
-    FORMAT=$2
-    debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format ${FORMAT} -o "${RESULTS_FILE}" ${QUERY_ARG}
+    debug_run env BENCH_NAME=tpch \
+        BENCH_SIZE="${SCALE_FACTOR}" \
+        PREFER_HASH_JOIN="${PREFER_HASH_JOIN}" \
+        TPCH_FILE_TYPE="${FORMAT}" \
+        SIMULATE_LATENCY="${SIMULATE_LATENCY}" \
+        ${QUERY:+BENCH_QUERY="${QUERY}"} \
+        bash -c "$SQL_CARGO_COMMAND"
 }
 
-# Runs the tpch in memory
+# Runs the tpch benchmark in memory (needs tpch parquet data)
 run_tpch_mem() {
     SCALE_FACTOR=$1
     if [ -z "$SCALE_FACTOR" ] ; then
         echo "Internal error: Scale factor not specified"
         exit 1
     fi
-    TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}"
+    echo "Running tpch_mem benchmark..."
+
+    debug_run env BENCH_NAME=tpch \
+        BENCH_SIZE="${SCALE_FACTOR}" \
+        TPCH_FILE_TYPE="mem" \
+        PREFER_HASH_JOIN="${PREFER_HASH_JOIN}" \
+        SIMULATE_LATENCY="${SIMULATE_LATENCY}" \
+        ${QUERY:+BENCH_QUERY="${QUERY}"} \
+        bash -c "$SQL_CARGO_COMMAND"
+}
+
+# Runs the tpcds benchmark
+run_tpcds() {
+    TPCDS_DIR="${DATA_DIR}/tpcds_sf1"
 
-    RESULTS_FILE="${RESULTS_DIR}/tpch_mem_sf${SCALE_FACTOR}.json"
+    # Check that the TPC-DS data directory and a representative file exist
+    if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then
+        echo "" >&2
+        echo "Please prepare the TPC-DS data first by running:" >&2
+        echo "  ./bench.sh data tpcds" >&2
+        echo "" >&2
+        exit 1
+    fi
+
+    RESULTS_FILE="${RESULTS_DIR}/tpcds_sf1.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
-    echo "Running tpch_mem benchmark..."
-    # -m means in memory
-    debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG}
+    echo "Running tpcds benchmark..."
+
+    debug_run $CARGO_COMMAND --bin dfbench -- tpcds --iterations 5 --path "${TPCDS_DIR}" --query_path "../datafusion/core/tests/tpc-ds" --prefer_hash_join "${PREFER_HASH_JOIN}" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG}
 }
 
 # Runs the compile profile benchmark helper
@@ -629,7 +743,7 @@ run_compile_profile() {
     local data_path="${DATA_DIR}/tpch_sf1"
 
     echo "Running compile profile benchmark..."
-    local cmd=(python3 "${runner}" --data "${data_path}")
+    local cmd=(uv run python3 "${runner}" --data "${data_path}")
     if [ ${#profiles[@]} -gt 0 ]; then
         cmd+=(--profiles "${profiles[@]}")
     fi
@@ -641,7 +755,7 @@ run_cancellation() {
     RESULTS_FILE="${RESULTS_DIR}/cancellation.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running cancellation benchmark..."
-    debug_run $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}"
+    debug_run $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" ${LATENCY_ARG}
 }
 
@@ -695,7 +809,7 @@ run_clickbench_1() {
     RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running clickbench (1 file) benchmark..."
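+    # QUERY_ARG and LATENCY_ARG are optional flag strings assembled earlier in
+    # this script (LATENCY_ARG presumably derived from SIMULATE_LATENCY); both
+    # expand to nothing when their corresponding options are unset.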
-    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG}
+    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG}
 }
 
 # Runs the clickbench benchmark with the partitioned parquet dataset (100 files)
@@ -703,7 +817,7 @@ run_clickbench_partitioned() {
     RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running clickbench (partitioned, 100 files) benchmark..."
-    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG}
+    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG}
 }
 
@@ -712,7 +826,7 @@ run_clickbench_pushdown() {
     RESULTS_FILE="${RESULTS_DIR}/clickbench_pushdown.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running clickbench (partitioned, 100 files) benchmark with pushdown_filters=true, reorder_filters=true..."
-    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --pushdown --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG}
+    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --pushdown --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG}
 }
 
@@ -721,7 +835,7 @@ run_clickbench_extended() {
     RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running clickbench (1 file) extended benchmark..."
-    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended" -o "${RESULTS_FILE}" ${QUERY_ARG}
+    debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG}
 }
 
 # Downloads the IMDB dataset csv.gz files from Peter Boncz's homepage (one of the JOB paper authors)
@@ -806,7 +920,7 @@ data_imdb() {
         if [ "${DOWNLOADED_SIZE}" != "${expected_size}" ]; then
             echo "Error: Download size mismatch"
             echo "Expected: ${expected_size}"
-            echo "Got: ${DOWNLADED_SIZE}"
+            echo "Got: ${DOWNLOADED_SIZE}"
             echo "Please re-initiate the download"
             return 1
         fi
@@ -836,7 +950,7 @@ run_imdb() {
     RESULTS_FILE="${RESULTS_DIR}/imdb.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running imdb benchmark..."
- debug_run $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } data_h2o() { @@ -844,75 +958,13 @@ data_h2o() { SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} - # Function to compare Python versions - version_ge() { - [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] - } - - export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 - - # Find the highest available Python version (3.10 or higher) - REQUIRED_VERSION="3.10" - PYTHON_CMD=$(command -v python3 || true) - - if [ -n "$PYTHON_CMD" ]; then - PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - echo "Found Python version $PYTHON_VERSION, which is suitable." - else - echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required." - PYTHON_CMD="" - fi - fi - - # Search for suitable Python versions if the default is unsuitable - if [ -z "$PYTHON_CMD" ]; then - # Loop through all available Python3 commands on the system - for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do - if command -v "$CMD" &> /dev/null; then - PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - PYTHON_CMD="$CMD" - echo "Found suitable Python version: $PYTHON_VERSION ($CMD)" - break - fi - fi - done - fi - - # If no suitable Python version found, exit with an error - if [ -z "$PYTHON_CMD" ]; then - echo "Python 3.10 or higher is required. Please install it." - return 1 - fi - - echo "Using Python command: $PYTHON_CMD" - - # Install falsa and other dependencies - echo "Installing falsa..." - - # Set virtual environment directory - VIRTUAL_ENV="${PWD}/venv" - - # Create a virtual environment using the detected Python command - $PYTHON_CMD -m venv "$VIRTUAL_ENV" - - # Activate the virtual environment and install dependencies - source "$VIRTUAL_ENV/bin/activate" - - # Ensure 'falsa' is installed (avoid unnecessary reinstall) - pip install --quiet --upgrade falsa - # Create directory if it doesn't exist H2O_DIR="${DATA_DIR}/h2o" mkdir -p "${H2O_DIR}" # Generate h2o test data echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}" - falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" - - # Deactivate virtual environment after completion - deactivate + uv run falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" } data_h2o_join() { @@ -920,75 +972,13 @@ data_h2o_join() { SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} - # Function to compare Python versions - version_ge() { - [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] - } - - export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 - - # Find the highest available Python version (3.10 or higher) - REQUIRED_VERSION="3.10" - PYTHON_CMD=$(command -v python3 || true) - - if [ -n "$PYTHON_CMD" ]; then - PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - echo "Found Python version $PYTHON_VERSION, which is suitable." 
-        else
-            echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required."
-            PYTHON_CMD=""
-        fi
-    fi
-
-    # Search for suitable Python versions if the default is unsuitable
-    if [ -z "$PYTHON_CMD" ]; then
-        # Loop through all available Python3 commands on the system
-        for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do
-            if command -v "$CMD" &> /dev/null; then
-                PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
-                if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then
-                    PYTHON_CMD="$CMD"
-                    echo "Found suitable Python version: $PYTHON_VERSION ($CMD)"
-                    break
-                fi
-            fi
-        done
-    fi
-
-    # If no suitable Python version found, exit with an error
-    if [ -z "$PYTHON_CMD" ]; then
-        echo "Python 3.10 or higher is required. Please install it."
-        return 1
-    fi
-
-    echo "Using Python command: $PYTHON_CMD"
-
-    # Install falsa and other dependencies
-    echo "Installing falsa..."
-
-    # Set virtual environment directory
-    VIRTUAL_ENV="${PWD}/venv"
-
-    # Create a virtual environment using the detected Python command
-    $PYTHON_CMD -m venv "$VIRTUAL_ENV"
-
-    # Activate the virtual environment and install dependencies
-    source "$VIRTUAL_ENV/bin/activate"
-
-    # Ensure 'falsa' is installed (avoid unnecessary reinstall)
-    pip install --quiet --upgrade falsa
-
     # Create directory if it doesn't exist
     H2O_DIR="${DATA_DIR}/h2o"
     mkdir -p "${H2O_DIR}"
 
     # Generate h2o test data
     echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}"
-    falsa join --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}"
-
-    # Deactivate virtual environment after completion
-    deactivate
+    uv run falsa join --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}"
 }
 
 # Runner for h2o groupby benchmark
@@ -1032,7 +1022,7 @@ run_h2o() {
         --path "${H2O_DIR}/${FILE_NAME}" \
         --queries-path "${QUERY_FILE}" \
         -o "${RESULTS_FILE}" \
-        ${QUERY_ARG}
+        ${QUERY_ARG} ${LATENCY_ARG}
 }
 
 # Utility function to run h2o join/window benchmark
@@ -1084,7 +1074,7 @@ h2o_runner() {
         --join-paths "${H2O_DIR}/${X_TABLE_FILE_NAME},${H2O_DIR}/${SMALL_TABLE_FILE_NAME},${H2O_DIR}/${MEDIUM_TABLE_FILE_NAME},${H2O_DIR}/${LARGE_TABLE_FILE_NAME}" \
         --queries-path "${QUERY_FILE}" \
         -o "${RESULTS_FILE}" \
-        ${QUERY_ARG}
+        ${QUERY_ARG} ${LATENCY_ARG}
 }
 
 # Runners for h2o join benchmark
@@ -1113,6 +1103,241 @@ run_external_aggr() {
     debug_run $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG}
 }
 
+# Generates sort pushdown benchmark data: TPC-H lineitem with 3 parts,
+# renamed so alphabetical order does NOT match sort key order.
+# This forces the sort pushdown optimizer to reorder files by statistics.
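+#
+# call like: data_sort_pushdown (takes no arguments; always generates at
+# scale factor 1)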
+# +# tpchgen produces 3 sorted, non-overlapping parquet files: +# lineitem.1.parquet: l_orderkey 1 ~ 2M (lowest keys) +# lineitem.2.parquet: l_orderkey 2M ~ 4M +# lineitem.3.parquet: l_orderkey 4M ~ 6M (highest keys) +# +# We rename them so alphabetical order is reversed: +# a_part3.parquet (highest keys, sorts first alphabetically) +# b_part2.parquet +# c_part1.parquet (lowest keys, sorts last alphabetically) +data_sort_pushdown() { + SORT_PUSHDOWN_DIR="${DATA_DIR}/sort_pushdown/lineitem" + if [ -d "${SORT_PUSHDOWN_DIR}" ] && [ "$(ls -A ${SORT_PUSHDOWN_DIR}/*.parquet 2>/dev/null)" ]; then + echo "Sort pushdown data already exists at ${SORT_PUSHDOWN_DIR}" + return + fi + + echo "Generating sort pushdown benchmark data (3 parts with reversed naming)..." + + TEMP_DIR="${DATA_DIR}/sort_pushdown_temp" + mkdir -p "${TEMP_DIR}" "${SORT_PUSHDOWN_DIR}" + + tpchgen-cli --scale-factor 1 --format parquet --parquet-compression='ZSTD(1)' --parts=3 --output-dir "${TEMP_DIR}" + + # Rename: reverse alphabetical order vs key order + mv "${TEMP_DIR}/lineitem/lineitem.3.parquet" "${SORT_PUSHDOWN_DIR}/a_part3.parquet" + mv "${TEMP_DIR}/lineitem/lineitem.2.parquet" "${SORT_PUSHDOWN_DIR}/b_part2.parquet" + mv "${TEMP_DIR}/lineitem/lineitem.1.parquet" "${SORT_PUSHDOWN_DIR}/c_part1.parquet" + + rm -rf "${TEMP_DIR}" + + echo "Sort pushdown data generated at ${SORT_PUSHDOWN_DIR}" + ls -la "${SORT_PUSHDOWN_DIR}" +} + +run_sort_pushdown() { + SORT_PUSHDOWN_DIR="${DATA_DIR}/sort_pushdown" + RESULTS_FILE="${RESULTS_DIR}/sort_pushdown.json" + echo "Running sort pushdown benchmark (no WITH ORDER)..." + debug_run $CARGO_COMMAND --bin dfbench -- sort-pushdown --iterations 5 --path "${SORT_PUSHDOWN_DIR}" --queries-path "${SCRIPT_DIR}/queries/sort_pushdown" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the sort pushdown benchmark with WITH ORDER (enables sort elimination) +run_sort_pushdown_sorted() { + SORT_PUSHDOWN_DIR="${DATA_DIR}/sort_pushdown" + RESULTS_FILE="${RESULTS_DIR}/sort_pushdown_sorted.json" + echo "Running sort pushdown benchmark (with WITH ORDER)..." + debug_run $CARGO_COMMAND --bin dfbench -- sort-pushdown --sorted --iterations 5 --path "${SORT_PUSHDOWN_DIR}" --queries-path "${SCRIPT_DIR}/queries/sort_pushdown" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Generates data for sort pushdown Inexact benchmark. +# +# Produces multiple parquet files where each file has MULTIPLE row groups +# with scrambled RG order. This tests both: +# - Row-group-level reorder within each file (reorder_by_statistics) +# - TopK threshold initialization from RG statistics +# +# Strategy: +# 1. Write a single sorted file with small (100K-row) RGs (~61 RGs total). +# 2. Use pyarrow to redistribute RGs into N_FILES files, scrambling the +# RG order within each file using a deterministic permutation. +# Each file gets ~61/N_FILES RGs with narrow, non-overlapping ranges +# but in scrambled order. +# +# Writing a single file with ORDER BY scramble does NOT work: the parquet +# writer merges rows from adjacent chunks at RG boundaries, widening +# ranges and defeating reorder_by_statistics. +# +# Requires pyarrow (pip install pyarrow). +data_sort_pushdown_inexact() { + INEXACT_DIR="${DATA_DIR}/sort_pushdown_inexact/lineitem" + if [ -d "${INEXACT_DIR}" ] && [ "$(ls -A ${INEXACT_DIR}/*.parquet 2>/dev/null)" ]; then + echo "Sort pushdown Inexact data already exists at ${INEXACT_DIR}" + return + fi + + # Check pyarrow dependency (needed to split/scramble RGs) + if ! 
python3 -c "import pyarrow" 2>/dev/null; then + echo "Error: pyarrow is required for sort pushdown Inexact data generation." + echo "Install with: pip install pyarrow" + return 1 + fi + + echo "Generating sort pushdown Inexact benchmark data (multi-file, scrambled RGs)..." + + # Re-use the sort_pushdown data as the source (generate if missing) + data_sort_pushdown + + mkdir -p "${INEXACT_DIR}" + SRC_DIR="${DATA_DIR}/sort_pushdown/lineitem" + + # Step 1: Write a single sorted file with small (100K-row) RGs + TMPFILE="${INEXACT_DIR}/_sorted_small_rgs.parquet" + (cd "${SCRIPT_DIR}/.." && cargo run --release -p datafusion-cli -- -c " + CREATE EXTERNAL TABLE src + STORED AS PARQUET + LOCATION '${SRC_DIR}'; + + COPY (SELECT * FROM src ORDER BY l_orderkey) + TO '${TMPFILE}' + STORED AS PARQUET + OPTIONS ('format.max_row_group_size' '100000'); + ") + + # Step 2: Redistribute RGs into 3 files with scrambled RG order. + # Each file gets ~20 RGs. RG assignment: rg_idx % 3 determines file, + # permutation (rg_idx * 41 + 7) % n scrambles the order within file. + python3 -c " +import pyarrow.parquet as pq + +pf = pq.ParquetFile('${TMPFILE}') +n = pf.metadata.num_row_groups +n_files = 3 + +# Assign each RG to a file, scramble order within each file +file_rgs = [[] for _ in range(n_files)] +for rg_idx in range(n): + slot = (rg_idx * 41 + 7) % n # scrambled index + file_id = slot % n_files + file_rgs[file_id].append(rg_idx) + +# Write each file with its assigned RGs (in scrambled order) +for file_id in range(n_files): + rgs = file_rgs[file_id] + if not rgs: + continue + tables = [pf.read_row_group(rg) for rg in rgs] + writer = pq.ParquetWriter( + '${INEXACT_DIR}/part_%03d.parquet' % file_id, + pf.schema_arrow) + for t in tables: + writer.write_table(t) + writer.close() + print(f'File part_{file_id:03d}.parquet: {len(rgs)} RGs') +" + + rm -f "${TMPFILE}" + echo "Sort pushdown Inexact data generated at ${INEXACT_DIR}" + ls -la "${INEXACT_DIR}" + + # Also generate overlap data: same strategy but with different file count + # and permutation. Simulates streaming data with network delays where + # chunks arrive out of sequence. + # + # Requires pyarrow (pip install pyarrow). + OVERLAP_DIR="${DATA_DIR}/sort_pushdown_inexact_overlap/lineitem" + if [ -d "${OVERLAP_DIR}" ] && [ "$(ls -A ${OVERLAP_DIR}/*.parquet 2>/dev/null)" ]; then + echo "Sort pushdown Inexact overlap data already exists at ${OVERLAP_DIR}" + return + fi + + echo "Generating sort pushdown Inexact overlap data (multi-file, scrambled RGs)..." + mkdir -p "${OVERLAP_DIR}" + + # Step 1: Write a single sorted file with small (100K-row) RGs + TMPFILE="${OVERLAP_DIR}/_sorted_small_rgs.parquet" + (cd "${SCRIPT_DIR}/.." && cargo run --release -p datafusion-cli -- -c " + CREATE EXTERNAL TABLE src + STORED AS PARQUET + LOCATION '${SRC_DIR}'; + + COPY (SELECT * FROM src ORDER BY l_orderkey) + TO '${TMPFILE}' + STORED AS PARQUET + OPTIONS ('format.max_row_group_size' '100000'); + ") + + # Step 2: Redistribute into 5 files with scrambled RG order. 
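+    # Note: the permutation here ((rg_idx * 37 + 13) % n) and the file count
+    # (5 vs 3) differ from the Inexact set above, so the two layouts are not
+    # identical.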
+ python3 -c " +import pyarrow.parquet as pq + +pf = pq.ParquetFile('${TMPFILE}') +n = pf.metadata.num_row_groups +n_files = 5 + +file_rgs = [[] for _ in range(n_files)] +for rg_idx in range(n): + slot = (rg_idx * 37 + 13) % n + file_id = slot % n_files + file_rgs[file_id].append(rg_idx) + +for file_id in range(n_files): + rgs = file_rgs[file_id] + if not rgs: + continue + tables = [pf.read_row_group(rg) for rg in rgs] + writer = pq.ParquetWriter( + '${OVERLAP_DIR}/part_%03d.parquet' % file_id, + pf.schema_arrow) + for t in tables: + writer.write_table(t) + writer.close() + print(f'File part_{file_id:03d}.parquet: {len(rgs)} RGs') +" + + rm -f "${TMPFILE}" +} + +# Runs the sort pushdown Inexact benchmark (tests RG reorder by statistics). +# Enables pushdown_filters so TopK's dynamic filter is pushed to the parquet +# reader for late materialization (only needed for Inexact path). +run_sort_pushdown_inexact() { + INEXACT_DIR="${DATA_DIR}/sort_pushdown_inexact" + RESULTS_FILE="${RESULTS_DIR}/sort_pushdown_inexact.json" + echo "Running sort pushdown Inexact benchmark (multi-file scrambled RGs, --sorted DESC)..." + DATAFUSION_EXECUTION_PARQUET_PUSHDOWN_FILTERS=true \ + debug_run $CARGO_COMMAND --bin dfbench -- sort-pushdown --sorted --iterations 5 --path "${INEXACT_DIR}" --queries-path "${SCRIPT_DIR}/queries/sort_pushdown_inexact" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the sort pushdown Inexact benchmark WITHOUT declared ordering. +# Tests the Unsupported path in try_pushdown_sort where RG reorder by +# statistics can still help TopK queries without any file ordering guarantee. +run_sort_pushdown_inexact_unsorted() { + INEXACT_DIR="${DATA_DIR}/sort_pushdown_inexact" + RESULTS_FILE="${RESULTS_DIR}/sort_pushdown_inexact_unsorted.json" + echo "Running sort pushdown Inexact benchmark (no WITH ORDER, Unsupported path)..." + DATAFUSION_EXECUTION_PARQUET_PUSHDOWN_FILTERS=true \ + debug_run $CARGO_COMMAND --bin dfbench -- sort-pushdown --iterations 5 --path "${INEXACT_DIR}" --queries-path "${SCRIPT_DIR}/queries/sort_pushdown_inexact_unsorted" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the sort pushdown benchmark with multi-file scrambled RG order. +# Simulates streaming data with network delays — multiple files, each with +# scrambled RGs. Tests both RG-level reorder and TopK stats initialization. +run_sort_pushdown_inexact_overlap() { + OVERLAP_DIR="${DATA_DIR}/sort_pushdown_inexact_overlap" + RESULTS_FILE="${RESULTS_DIR}/sort_pushdown_inexact_overlap.json" + echo "Running sort pushdown Inexact benchmark (multi-file scrambled RGs, streaming data pattern)..." + DATAFUSION_EXECUTION_PARQUET_PUSHDOWN_FILTERS=true \ + debug_run $CARGO_COMMAND --bin dfbench -- sort-pushdown --sorted --iterations 5 --path "${OVERLAP_DIR}" --queries-path "${SCRIPT_DIR}/queries/sort_pushdown_inexact_overlap" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + # Runs the sort integration benchmark run_sort_tpch() { SCALE_FACTOR=$1 @@ -1125,7 +1350,7 @@ run_sort_tpch() { echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort tpch benchmark..." - debug_run $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } # Runs the sort tpch integration benchmark with limit 100 (topk) @@ -1135,7 +1360,7 @@ run_topk_tpch() { echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running topk tpch benchmark..." 
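+    # Note: identical to run_sort_tpch except for --limit 100, which lets the
+    # planner use a TopK instead of a full sort.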
- $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100 ${QUERY_ARG} + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100 ${QUERY_ARG} ${LATENCY_ARG} } # Runs the nlj benchmark @@ -1143,15 +1368,24 @@ run_nlj() { RESULTS_FILE="${RESULTS_DIR}/nlj.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running nlj benchmark..." - debug_run $CARGO_COMMAND --bin dfbench -- nlj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin dfbench -- nlj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } # Runs the hj benchmark run_hj() { + TPCH_DIR="${DATA_DIR}/tpch_sf10" RESULTS_FILE="${RESULTS_DIR}/hj.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running hj benchmark..." - debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the smj benchmark +run_smj() { + RESULTS_FILE="${RESULTS_DIR}/smj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running smj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- smj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } @@ -1181,7 +1415,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${BENCH}" echo "--------------------" - PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" + uv run python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi @@ -1189,10 +1423,113 @@ compare_benchmarks() { } -setup_venv() { - python3 -m venv "$VIRTUAL_ENV" - PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt +# Creates sorted ClickBench data from hits.parquet (full dataset) +# The data is sorted by EventTime in ascending order +# Uses datafusion-cli to reduce dependencies +clickbench_sorted() { + SORTED_FILE="${DATA_DIR}/hits_sorted.parquet" + ORIGINAL_FILE="${DATA_DIR}/hits.parquet" + + # Default memory limit is 12GB, can be overridden with DATAFUSION_MEMORY_GB env var + MEMORY_LIMIT_GB=${DATAFUSION_MEMORY_GB:-12} + + echo "Creating sorted ClickBench dataset from hits.parquet..." + echo "Configuration:" + echo " Memory limit: ${MEMORY_LIMIT_GB}G" + echo " Row group size: 64K rows" + echo " Compression: uncompressed" + + if [ ! -f "${ORIGINAL_FILE}" ]; then + echo "hits.parquet not found. Running data_clickbench_1 first..." + data_clickbench_1 + fi + + if [ -f "${SORTED_FILE}" ]; then + echo "Sorted hits.parquet already exists at ${SORTED_FILE}" + return 0 + fi + + echo "Sorting hits.parquet by EventTime (this may take several minutes)..." + + pushd "${DATAFUSION_DIR}" > /dev/null + echo "Building datafusion-cli..." + cargo build --release --bin datafusion-cli + DATAFUSION_CLI="${DATAFUSION_DIR}/target/release/datafusion-cli" + popd > /dev/null + + + START_TIME=$(date +%s) + echo "Start time: $(date '+%Y-%m-%d %H:%M:%S')" + echo "Using datafusion-cli to create sorted parquet file..." 
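+    # The SET statements in the heredoc below keep the sort within the memory
+    # limit; target_partitions = 1 ensures the COPY produces a single,
+    # globally sorted output file.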
+ "${DATAFUSION_CLI}" << EOF +-- Memory and performance configuration +SET datafusion.runtime.memory_limit = '${MEMORY_LIMIT_GB}G'; +SET datafusion.execution.spill_compression = 'uncompressed'; +SET datafusion.execution.sort_spill_reservation_bytes = 10485760; -- 10MB +SET datafusion.execution.batch_size = 8192; +SET datafusion.execution.target_partitions = 1; + +-- Parquet output configuration +SET datafusion.execution.parquet.max_row_group_size = 65536; +SET datafusion.execution.parquet.compression = 'uncompressed'; + +-- Execute sort and write +COPY (SELECT * FROM '${ORIGINAL_FILE}' ORDER BY "EventTime") +TO '${SORTED_FILE}' +STORED AS PARQUET; +EOF + + local result=$? + + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "End time: $(date '+%Y-%m-%d %H:%M:%S')" + + if [ $result -eq 0 ]; then + echo "✓ Successfully created sorted ClickBench dataset" + + INPUT_SIZE=$(stat -f%z "${ORIGINAL_FILE}" 2>/dev/null || stat -c%s "${ORIGINAL_FILE}" 2>/dev/null) + OUTPUT_SIZE=$(stat -f%z "${SORTED_FILE}" 2>/dev/null || stat -c%s "${SORTED_FILE}" 2>/dev/null) + INPUT_MB=$((INPUT_SIZE / 1024 / 1024)) + OUTPUT_MB=$((OUTPUT_SIZE / 1024 / 1024)) + + echo " Input: ${INPUT_MB} MB" + echo " Output: ${OUTPUT_MB} MB" + + echo "" + echo "Time Statistics:" + echo " Total duration: ${DURATION} seconds ($(printf '%02d:%02d:%02d' $((DURATION/3600)) $((DURATION%3600/60)) $((DURATION%60))))" + echo " Throughput: $((INPUT_MB / DURATION)) MB/s" + + return 0 + else + echo "✗ Error: Failed to create sorted dataset" + echo "💡 Tip: Try increasing memory with: DATAFUSION_MEMORY_GB=16 ./bench.sh data clickbench_sorted" + return 1 + fi +} + +# Runs the sorted data benchmark with prefer_existing_sort configuration +run_clickbench_sorted() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_sorted.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running sorted data benchmark with prefer_existing_sort optimization..." + + # Ensure sorted data exists + clickbench_sorted + + # Run benchmark with prefer_existing_sort configuration + # This allows DataFusion to optimize away redundant sorts while maintaining parallelism + debug_run $CARGO_COMMAND --bin dfbench -- clickbench \ + --iterations 5 \ + --path "${DATA_DIR}/hits_sorted.parquet" \ + --queries-path "${SCRIPT_DIR}/queries/clickbench/queries/sorted_data" \ + --sorted-by "EventTime" \ + -c datafusion.optimizer.prefer_existing_sort=true \ + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} ${LATENCY_ARG} } + # And start the process up main diff --git a/benchmarks/benches/sql.rs b/benchmarks/benches/sql.rs new file mode 100644 index 0000000000000..eade3194d1402 --- /dev/null +++ b/benchmarks/benches/sql.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Criterion benchmark harness for SQL benchmark files under `sql_benchmarks`.
+//!
+//! SQL benchmarks describe setup, queries, result validation, and cleanup in
+//! `.benchmark` files. Run them with `benchmarks/bench.sh` or directly with
+//! Cargo, for example: `BENCH_NAME=tpch cargo bench --bench sql`.
+
+use clap::Parser;
+use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
+use datafusion::error::Result;
+use datafusion::prelude::SessionContext;
+use datafusion_benchmarks::sql_benchmark::SqlBenchmark;
+use datafusion_benchmarks::util::{CommonOpt, print_memory_stats};
+use datafusion_common::instant::Instant;
+use log::{debug, info};
+use std::collections::BTreeMap;
+use std::fs;
+use std::sync::LazyLock;
+use tokio::runtime::Runtime;
+
+static SQL_BENCHMARK_DIRECTORY: LazyLock<String> = LazyLock::new(|| {
+    format!(
+        "{}{}{}",
+        env!("CARGO_MANIFEST_DIR"),
+        std::path::MAIN_SEPARATOR,
+        "sql_benchmarks"
+    )
+});
+
+#[cfg(all(feature = "snmalloc", feature = "mimalloc"))]
+compile_error!(
+    "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time"
+);
+
+#[cfg(feature = "snmalloc")]
+#[global_allocator]
+static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
+
+#[cfg(feature = "mimalloc")]
+#[global_allocator]
+static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
+
+#[derive(Debug, Parser)]
+#[command(ignore_errors = true)]
+struct EnvParser {
+    #[command(flatten)]
+    options: CommonOpt,
+
+    #[arg(
+        env = "BENCH_PERSIST_RESULTS",
+        long = "persist_results",
+        default_value = "false",
+        action = clap::ArgAction::SetTrue
+    )]
+    persist_results: bool,
+
+    #[arg(
+        env = "BENCH_VALIDATE",
+        long = "validate_results",
+        default_value = "false",
+        action = clap::ArgAction::SetTrue
+    )]
+    validate: bool,
+
+    #[arg(env = "BENCH_NAME")]
+    name: Option<String>,
+
+    #[arg(env = "BENCH_SUBGROUP")]
+    subgroup: Option<String>,
+
+    #[arg(env = "BENCH_QUERY")]
+    query: Option<String>,
+}
+
+pub fn sql(c: &mut Criterion) {
+    env_logger::init();
+
+    let start = Instant::now();
+    let args = EnvParser::parse();
+    let rt = make_tokio_runtime();
+
+    println!("Loading benchmarks...");
+
+    let benchmarks = rt.block_on(async {
+        let ctx = make_ctx(&args).expect("SessionContext creation failed");
+
+        load_benchmarks(&args, &ctx, &SQL_BENCHMARK_DIRECTORY)
+            .await
+            .unwrap_or_else(|err| panic!("failed to load benchmarks: {err:?}"))
+    });
+
+    println!(
+        "Loaded benchmarks in {} ms ...",
+        start.elapsed().as_millis()
+    );
+
+    for (group, benchmarks) in benchmarks {
+        let mut group = c.benchmark_group(group);
+        group.sample_size(10);
+        group.sampling_mode(SamplingMode::Flat);
+
+        for mut benchmark in benchmarks {
+            // create a context
+            let ctx = make_ctx(&args).expect("SessionContext creation failed");
+
+            // initialize the benchmark. This parses the benchmark file and does any pre-execution
+            // work such as loading data into tables
+            rt.block_on(async {
+                benchmark
+                    .initialize(&ctx)
+                    .await
+                    .expect("initialization failed");
+
+                // run assertions
+                benchmark.assert(&ctx).await.expect("assertion failed");
+            });
+
+            let mut name = benchmark.name().to_string();
+            if !benchmark.subgroup().is_empty() {
+                name.push('_');
+                name.push_str(benchmark.subgroup());
+            }
+
+            if args.persist_results {
+                handle_persist(&rt, &ctx, &name, &mut benchmark);
+            } else if args.validate {
+                handle_verify(&rt, &ctx, &name, &mut benchmark);
+            } else {
+                info!("Running benchmark {name} ...");
+
+                let name = name.clone();
+                group.bench_function(name.clone(), |b| {
+                    b.iter(|| handle_run(&rt, &ctx, &args, &mut benchmark, &name))
+                });
+
+                print_memory_stats();
+
+                info!("Benchmark {name} completed");
+            }
+
+            // run cleanup
+            rt.block_on(async {
+                benchmark.cleanup(&ctx).await.expect("Cleanup failed");
+            });
+        }
+
+        group.finish();
+    }
+}
+
+fn handle_run(
+    rt: &Runtime,
+    ctx: &SessionContext,
+    args: &EnvParser,
+    benchmark: &mut SqlBenchmark,
+    name: &str,
+) {
+    rt.block_on(async {
+        benchmark
+            .run(ctx, args.validate)
+            .await
+            .unwrap_or_else(|err| panic!("Failed to run benchmark {name}: {err:?}"))
+    });
+}
+
+fn handle_persist(
+    rt: &Runtime,
+    ctx: &SessionContext,
+    name: &str,
+    benchmark: &mut SqlBenchmark,
+) {
+    info!("Running benchmark {name} prior to persisting results ...");
+
+    rt.block_on(async {
+        info!("Persisting benchmark {name} ...");
+
+        benchmark
+            .persist(ctx)
+            .await
+            .expect("Failed to persist results");
+    });
+
+    info!("Persisted benchmark {name} successfully");
+}
+
+fn handle_verify(
+    rt: &Runtime,
+    ctx: &SessionContext,
+    name: &str,
+    benchmark: &mut SqlBenchmark,
+) {
+    info!("Verifying benchmark {name} results ...");
+
+    rt.block_on(async {
+        benchmark
+            .run(ctx, true)
+            .await
+            .unwrap_or_else(|err| panic!("Failed to run benchmark {name}: {err:?}"));
+        benchmark
+            .verify(ctx)
+            .await
+            .unwrap_or_else(|err| panic!("Verification failed: {err:?}"));
+    });
+
+    info!("Verified benchmark {name} results successfully");
+}
+
+criterion_group!(benches, sql);
+criterion_main!(benches);
+
+fn make_tokio_runtime() -> Runtime {
+    tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .build()
+        .unwrap()
+}
+
+fn make_ctx(args: &EnvParser) -> Result<SessionContext> {
+    let config = args.options.config()?;
+    let rt = args.options.build_runtime()?;
+
+    Ok(SessionContext::new_with_config_rt(config, rt))
+}
+
+/// Recursively walks the directory tree starting at `path` and
+/// calls the callback function for every file encountered.
+pub fn list_files<F>(path: &str, callback: &mut F)
+where
+    F: FnMut(&str),
+{
+    let mut entries: Vec<fs::DirEntry> =
+        fs::read_dir(path).unwrap().filter_map(Result::ok).collect();
+    entries.sort_by_key(|entry| entry.path());
+
+    for dir_entry in entries {
+        let path = dir_entry.path();
+        if path.is_dir() {
+            // Recurse into the sub-directory
+            list_files(&path.to_string_lossy(), callback);
+        } else {
+            // For files, invoke the callback with the full path as a string
+            let full_str = path.to_string_lossy();
+            callback(&full_str);
+        }
+    }
+}
+
+/// Loads all benchmark files in the `sql_benchmarks` directory.
+/// For each file ending with `.benchmark` it creates a new
+/// `SqlBenchmark` instance.
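+/// Benchmarks are grouped by `SqlBenchmark::group()`; within a group they are
+/// sorted by name (query names are zero-padded, e.g. `Q01`, `Q02`, ...).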
+async fn load_benchmarks(
+    args: &EnvParser,
+    ctx: &SessionContext,
+    path: &str,
+) -> Result<BTreeMap<String, Vec<SqlBenchmark>>> {
+    let mut benches = BTreeMap::new();
+    let mut paths = Vec::new();
+
+    list_files(path, &mut |path: &str| {
+        if path.ends_with(".benchmark") {
+            paths.push(path.to_string());
+        }
+    });
+
+    for path in paths {
+        debug!("Loading benchmark from {path}");
+
+        let benchmark = SqlBenchmark::new(ctx, &path, &*SQL_BENCHMARK_DIRECTORY).await?;
+        let entries = benches
+            .entry(benchmark.group().to_string())
+            .or_insert(vec![]);
+
+        entries.push(benchmark);
+    }
+
+    benches = filter_benchmarks(args, benches);
+    benches.iter_mut().for_each(|(_, benchmarks)| {
+        benchmarks.sort_by(|b1, b2| b1.name().cmp(b2.name()))
+    });
+
+    Ok(benches)
+}
+
+fn filter_benchmarks(
+    args: &EnvParser,
+    benchmarks: BTreeMap<String, Vec<SqlBenchmark>>,
+) -> BTreeMap<String, Vec<SqlBenchmark>> {
+    match &args.name {
+        Some(bench_name) => benchmarks
+            .into_iter()
+            .filter(|(key, _val)| key.eq_ignore_ascii_case(bench_name))
+            .map(|(key, mut val)| {
+                if let Some(subgroup) = &args.subgroup {
+                    val.retain(|bench| bench.subgroup().eq_ignore_ascii_case(subgroup));
+                }
+                if let Some(query_number) = &args.query {
+                    let padded = format!("Q{query_number:0>2}");
+                    val.retain(|bench| bench.name().eq_ignore_ascii_case(&padded));
+                }
+                (key, val)
+            })
+            .collect(),
+        None => benchmarks,
+    }
+}
diff --git a/benchmarks/compare.py b/benchmarks/compare.py
index 7e51a38a92c2b..9ad1de980abe8 100755
--- a/benchmarks/compare.py
+++ b/benchmarks/compare.py
@@ -154,17 +154,17 @@ def compare(
     baseline = BenchmarkRun.load_from_file(baseline_path)
     comparison = BenchmarkRun.load_from_file(comparison_path)
 
-    console = Console()
+    console = Console(width=200)
 
     # use basename as the column names
-    baseline_header = baseline_path.parent.stem
-    comparison_header = comparison_path.parent.stem
+    baseline_header = baseline_path.parent.name
+    comparison_header = comparison_path.parent.name
 
     table = Table(show_header=True, header_style="bold magenta")
-    table.add_column("Query", style="dim", width=12)
-    table.add_column(baseline_header, justify="right", style="dim")
-    table.add_column(comparison_header, justify="right", style="dim")
-    table.add_column("Change", justify="right", style="dim")
+    table.add_column("Query", style="dim", no_wrap=True)
+    table.add_column(baseline_header, justify="right", style="dim", no_wrap=True)
+    table.add_column(comparison_header, justify="right", style="dim", no_wrap=True)
+    table.add_column("Change", justify="right", style="dim", no_wrap=True)
 
     faster_count = 0
     slower_count = 0
@@ -175,12 +175,12 @@ def compare(
 
     for baseline_result, comparison_result in zip(baseline.queries, comparison.queries):
         assert baseline_result.query == comparison_result.query
-        
+
         base_failed = not baseline_result.success
-        comp_failed = not comparison_result.success 
+        comp_failed = not comparison_result.success
         # If a query fails, its execution time is excluded from the performance comparison
         if base_failed or comp_failed:
-            change_text = "incomparable" 
+            change_text = "incomparable"
             failure_count += 1
             table.add_row(
                 f"Q{baseline_result.query}",
diff --git a/benchmarks/compare_tpcds.sh b/benchmarks/compare_tpcds.sh
new file mode 100755
index 0000000000000..48331a7c7510e
--- /dev/null
+++ b/benchmarks/compare_tpcds.sh
@@ -0,0 +1,58 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Compare TPC-DS benchmarks between two branches
+
+set -e
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+usage() {
+    echo "Usage: $0 <branch1> <branch2>"
+    echo ""
+    echo "Example: $0 main dev2"
+    echo ""
+    echo "Note: requires TPC-DS data; generate it first with ./bench.sh data tpcds"
+    exit 1
+}
+
+BRANCH1=${1:-""}
+BRANCH2=${2:-""}
+
+if [ -z "$BRANCH1" ] || [ -z "$BRANCH2" ]; then
+    usage
+fi
+
+# Store current branch
+CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
+
+echo "Comparing TPC-DS benchmarks: ${BRANCH1} vs ${BRANCH2}"
+
+# Run benchmark on first branch
+git checkout "$BRANCH1"
+./benchmarks/bench.sh run tpcds
+
+# Run benchmark on second branch
+git checkout "$BRANCH2"
+./benchmarks/bench.sh run tpcds
+
+# Compare results
+./benchmarks/bench.sh compare "$BRANCH1" "$BRANCH2"
+
+# Return to original branch
+git checkout "$CURRENT_BRANCH"
\ No newline at end of file
diff --git a/benchmarks/compare_tpch.sh b/benchmarks/compare_tpch.sh
new file mode 100755
index 0000000000000..85e8da29ce41d
--- /dev/null
+++ b/benchmarks/compare_tpch.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Compare TPC-H benchmarks between two branches
+
+set -e
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+usage() {
+    echo "Usage: $0 <branch1> <branch2>"
+    echo ""
+    echo "Example: $0 main dev2"
+    exit 1
+}
+
+BRANCH1=${1:-""}
+BRANCH2=${2:-""}
+
+if [ -z "$BRANCH1" ] || [ -z "$BRANCH2" ]; then
+    usage
+fi
+
+# Store current branch
+CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
+
+echo "Comparing TPC-H benchmarks: ${BRANCH1} vs ${BRANCH2}"
+
+# Run benchmark on first branch
+git checkout "$BRANCH1"
+./benchmarks/bench.sh run tpch
+
+# Run benchmark on second branch
+git checkout "$BRANCH2"
+./benchmarks/bench.sh run tpch
+
+# Compare results
+./benchmarks/bench.sh compare "$BRANCH1" "$BRANCH2"
+
+# Return to original branch
+git checkout "$CURRENT_BRANCH"
\ No newline at end of file
diff --git a/benchmarks/compile_profile.py b/benchmarks/compile_profile.py
index ae51de94937bf..a85e15ddacc04 100644
--- a/benchmarks/compile_profile.py
+++ b/benchmarks/compile_profile.py
@@ -19,8 +19,10 @@
 
 """Compile profile benchmark runner for DataFusion.
 
-Builds the `tpch` benchmark binary with several Cargo profiles (e.g. `--release` or `--profile ci`), runs the full TPC-H suite against the Parquet data under `benchmarks/data/tpch_sf1`, and reports compile time, execution time, and resulting
-binary size.
+Builds the `dfbench` benchmark binary with several Cargo profiles
+(e.g. `--release` or `--profile ci`), runs the full TPC-H suite against
+the Parquet data under `benchmarks/data/tpch_sf1`, and reports compile
+time, execution time, and resulting binary size.
 
 See `benchmarks/README.md` for usage.
 """
@@ -40,12 +42,15 @@
 DEFAULT_ITERATIONS = 1
 DEFAULT_FORMAT = "parquet"
 DEFAULT_PARTITIONS: int | None = None
-TPCH_BINARY = "tpch.exe" if os.name == "nt" else "tpch"
+BENCHMARK_PACKAGE = "datafusion-benchmarks"
+BENCHMARK_BINARY = "dfbench.exe" if os.name == "nt" else "dfbench"
 
 PROFILE_TARGET_DIR = {
     "dev": "debug",
     "release": "release",
     "ci": "ci",
+    "ci-optimized": "ci-optimized",
     "release-nonlto": "release-nonlto",
+    "profiling": "profiling",
 }
 
@@ -62,7 +67,10 @@ def parse_args() -> argparse.Namespace:
         "--profiles",
         nargs="+",
         default=list(PROFILE_TARGET_DIR.keys()),
-        help="Cargo profiles to test (default: dev release ci release-nonlto)",
+        help=(
+            "Cargo profiles to test "
+            "(default: dev release ci ci-optimized release-nonlto profiling)"
+        ),
     )
     parser.add_argument(
         "--data",
@@ -84,9 +92,25 @@ def timed_run(command: Iterable[str]) -> float:
 
 def cargo_build(profile: str) -> float:
     if profile == "dev":
-        command = ["cargo", "build", "--bin", "tpch"]
+        command = [
+            "cargo",
+            "build",
+            "--package",
+            BENCHMARK_PACKAGE,
+            "--bin",
+            "dfbench",
+        ]
     else:
-        command = ["cargo", "build", "--profile", profile, "--bin", "tpch"]
+        command = [
+            "cargo",
+            "build",
+            "--profile",
+            profile,
+            "--package",
+            BENCHMARK_PACKAGE,
+            "--bin",
+            "dfbench",
+        ]
     return timed_run(command)
 
@@ -102,14 +126,13 @@ def run_benchmark(profile: str, data_path: 
Path) -> float: def binary_size(profile: str) -> int: binary_dir = PROFILE_TARGET_DIR[profile] - binary_path = REPO_ROOT / "target" / binary_dir / TPCH_BINARY + binary_path = REPO_ROOT / "target" / binary_dir / BENCHMARK_BINARY return binary_path.stat().st_size diff --git a/benchmarks/lineprotocol.py b/benchmarks/lineprotocol.py index 75e09b662e3e1..40f643499f489 100644 --- a/benchmarks/lineprotocol.py +++ b/benchmarks/lineprotocol.py @@ -164,12 +164,12 @@ def lineformat( ) -> None: baseline = BenchmarkRun.load_from_file(baseline) context = baseline.context - benchamrk_str = f"benchmark,name={context.name},version={context.benchmark_version},datafusion_version={context.datafusion_version},num_cpus={context.num_cpus}" + benchmark_str = f"benchmark,name={context.name},version={context.benchmark_version},datafusion_version={context.datafusion_version},num_cpus={context.num_cpus}" for query in baseline.queries: query_str = f"query=\"{query.query}\"" timestamp = f"{query.start_time*10**9}" for iter_num, result in enumerate(query.iterations): - print(f"{benchamrk_str} {query_str},iteration={iter_num},row_count={result.row_count},elapsed_ms={result.elapsed*1000:.0f} {timestamp}\n") + print(f"{benchmark_str} {query_str},iteration={iter_num},row_count={result.row_count},elapsed_ms={result.elapsed*1000:.0f} {timestamp}\n") def main() -> None: parser = ArgumentParser() diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml new file mode 100644 index 0000000000000..e6a60582148ce --- /dev/null +++ b/benchmarks/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "datafusion-benchmarks" +version = "0.1.0" +requires-python = ">=3.11" +# typing_extensions is an undeclared dependency of falsa +dependencies = ["rich", "falsa", "typing_extensions"] diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index 877ea0e0c3192..8b3d08b128866 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -228,6 +228,41 @@ Results look like Elapsed 30.195 seconds. ``` + +### Q9-Q12: FIRST_VALUE Aggregation Performance + +These queries test the performance of the `FIRST_VALUE` aggregation function with different data types and grouping cardinalities. + +| Query | `FIRST_VALUE` Column | Column Type | Group By Column | Group By Type | Number of Groups | +|-------|----------------------|-------------|-----------------|---------------|------------------| +| Q9 | `URL` | `Utf8` | `UserID` | `Int64` | 17,630,976 | +| Q10 | `URL` | `Utf8` | `OS` | `Int16` | 91 | +| Q11 | `WatchID` | `Int64` | `UserID` | `Int64` | 17,630,976 | +| Q12 | `WatchID` | `Int64` | `OS` | `Int16` | 91 | + + +### Q13: Filter-only URL Range Match + +**Question**: "What is the sum of counter IDs for page views with URLs in the normal URL string range?" + +**Important Query Properties**: Filter-only string range match. The `URL` +column is used only by the pushed-down filter and is not projected or +aggregated. This makes the query useful for measuring optimizations that can +skip RowFilter evaluation when Parquet row group statistics prove that all rows +in a row group satisfy the string predicate. The output-side aggregation is +intentionally lightweight so the scan-time filter evaluation cost remains +visible. Run this query with Parquet filter pushdown enabled, for example +`dfbench clickbench --pushdown --query 13`. 
+ +```sql +SELECT SUM("CounterID") AS counter_id_sum +FROM hits +WHERE "URL" < 'zzzz'; +``` + + + + ## Data Notes Here are some interesting statistics about the data used in the queries diff --git a/benchmarks/queries/clickbench/extended/q10.sql b/benchmarks/queries/clickbench/extended/q10.sql new file mode 100644 index 0000000000000..d6019de17854f --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q10.sql @@ -0,0 +1,8 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT MAX(len) FROM ( + SELECT LENGTH(FIRST_VALUE("URL" ORDER BY "EventTime")) as len + FROM hits + GROUP BY "OS" +); diff --git a/benchmarks/queries/clickbench/extended/q11.sql b/benchmarks/queries/clickbench/extended/q11.sql new file mode 100644 index 0000000000000..bca38f836bb95 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q11.sql @@ -0,0 +1,8 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT MAX(fv) FROM ( + SELECT FIRST_VALUE("WatchID" ORDER BY "EventTime") as fv + FROM hits + GROUP BY "UserID" +); diff --git a/benchmarks/queries/clickbench/extended/q12.sql b/benchmarks/queries/clickbench/extended/q12.sql new file mode 100644 index 0000000000000..fa062ac1f5cde --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q12.sql @@ -0,0 +1,8 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT MAX(fv) FROM ( + SELECT FIRST_VALUE("WatchID" ORDER BY "EventTime") as fv + FROM hits + GROUP BY "OS" +); diff --git a/benchmarks/queries/clickbench/extended/q13.sql b/benchmarks/queries/clickbench/extended/q13.sql new file mode 100644 index 0000000000000..b76a0766566b3 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q13.sql @@ -0,0 +1,6 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT SUM("CounterID") AS counter_id_sum +FROM hits +WHERE "URL" < 'zzzz'; diff --git a/benchmarks/queries/clickbench/extended/q8.sql b/benchmarks/queries/clickbench/extended/q8.sql new file mode 100644 index 0000000000000..e580807841df5 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q8.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "RegionID", "UserAgent", "OS", AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ResponseStartTiming")) as avg_response_time, AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ConnectTiming")) as avg_latency FROM hits GROUP BY "RegionID", "UserAgent", "OS" ORDER BY avg_latency DESC limit 10; \ No newline at end of file diff --git a/benchmarks/queries/clickbench/extended/q9.sql b/benchmarks/queries/clickbench/extended/q9.sql new file mode 100644 index 0000000000000..53952ebec2627 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q9.sql @@ -0,0 +1,8 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT MAX(len) FROM ( + SELECT LENGTH(FIRST_VALUE("URL" ORDER BY "EventTime")) as len + FROM hits + GROUP BY "UserID" +); diff --git a/benchmarks/queries/clickbench/queries/sorted_data/q0.sql b/benchmarks/queries/clickbench/queries/sorted_data/q0.sql new file mode 100644 index 0000000000000..1170a383bcb22 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/sorted_data/q0.sql @@ -0,0 +1,3 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true +SELECT * FROM hits ORDER BY "EventTime" DESC limit 10; diff --git a/benchmarks/queries/h2o/window.sql b/benchmarks/queries/h2o/window.sql index 071540927a4cf..fa16a3de32ca5 100644 --- a/benchmarks/queries/h2o/window.sql +++ b/benchmarks/queries/h2o/window.sql @@ -109,4 +109,11 @@ SELECT id3, v2, sum(v2) OVER (PARTITION BY id2 ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between_by_id2 -FROM large; \ No newline at end of file +FROM large; + +-- Window Top-N (ROW_NUMBER top-2 per partition) +SELECT id2, largest2_v2 FROM ( + SELECT id2, v2 AS largest2_v2, + ROW_NUMBER() OVER (PARTITION BY id2 ORDER BY v2 DESC) AS order_v2 + FROM large WHERE v2 IS NOT NULL +) sub_query WHERE order_v2 <= 2; diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql index 8613fd4962837..8ac2fd90798c9 100644 --- a/benchmarks/queries/q10.sql +++ b/benchmarks/queries/q10.sql @@ -16,7 +16,7 @@ where c_custkey = o_custkey and l_orderkey = o_orderkey and o_orderdate >= date '1993-10-01' - and o_orderdate < date '1994-01-01' + and o_orderdate < date '1993-10-01' + interval '3' month and l_returnflag = 'R' and c_nationkey = n_nationkey group by diff --git a/benchmarks/queries/q11.sql b/benchmarks/queries/q11.sql index c23ed1c71bfb3..9a9710d09ec35 100644 --- a/benchmarks/queries/q11.sql +++ b/benchmarks/queries/q11.sql @@ -13,7 +13,7 @@ group by ps_partkey having sum(ps_supplycost * ps_availqty) > ( select - sum(ps_supplycost * ps_availqty) * 0.0001 + sum(ps_supplycost * ps_availqty) * 0.0001 /* __TPCH_Q11_FRACTION__ */ from partsupp, supplier, @@ -24,4 +24,4 @@ group by and n_name = 'GERMANY' ) order by - value desc; \ No newline at end of file + value desc; diff --git a/benchmarks/queries/q12.sql b/benchmarks/queries/q12.sql index f8e6d960c8420..c3f4d62344701 100644 --- a/benchmarks/queries/q12.sql +++ b/benchmarks/queries/q12.sql @@ -23,8 +23,8 @@ where and l_commitdate < l_receiptdate and l_shipdate < l_commitdate and l_receiptdate >= date '1994-01-01' - and l_receiptdate < date '1995-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year group by l_shipmode order by - l_shipmode; \ No newline at end of file + l_shipmode; diff --git a/benchmarks/queries/q14.sql b/benchmarks/queries/q14.sql index d8ef6afaca9bb..6fe88c42662d0 100644 --- a/benchmarks/queries/q14.sql +++ b/benchmarks/queries/q14.sql @@ -10,4 +10,4 @@ from where l_partkey = p_partkey and l_shipdate >= date '1995-09-01' - and l_shipdate < date '1995-10-01'; \ No newline at end of file + and l_shipdate < date '1995-09-01' + interval '1' month; diff --git a/benchmarks/queries/q5.sql b/benchmarks/queries/q5.sql index 5a336b231184b..146980ccd6f76 100644 --- a/benchmarks/queries/q5.sql +++ b/benchmarks/queries/q5.sql @@ -17,8 +17,8 @@ where and n_regionkey = r_regionkey and r_name = 'ASIA' and o_orderdate >= date '1994-01-01' - and 
o_orderdate < date '1995-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year group by n_name order by - revenue desc; \ No newline at end of file + revenue desc; diff --git a/benchmarks/queries/q6.sql b/benchmarks/queries/q6.sql index 5806f980f8088..5a13fe7df765a 100644 --- a/benchmarks/queries/q6.sql +++ b/benchmarks/queries/q6.sql @@ -4,6 +4,6 @@ from lineitem where l_shipdate >= date '1994-01-01' - and l_shipdate < date '1995-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year and l_discount between 0.06 - 0.01 and 0.06 + 0.01 - and l_quantity < 24; \ No newline at end of file + and l_quantity < 24; diff --git a/benchmarks/queries/sort_pushdown/q1.sql b/benchmarks/queries/sort_pushdown/q1.sql new file mode 100644 index 0000000000000..f5f51a5d4043e --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q1.sql @@ -0,0 +1,6 @@ +-- Sort elimination: ORDER BY sort key ASC (full scan) +-- With --sorted: SortExec removed, sequential scan in file order +-- Without --sorted: full SortExec required +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey diff --git a/benchmarks/queries/sort_pushdown/q2.sql b/benchmarks/queries/sort_pushdown/q2.sql new file mode 100644 index 0000000000000..29a0e127cb7c6 --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q2.sql @@ -0,0 +1,7 @@ +-- Sort elimination + limit pushdown +-- With --sorted: SortExec removed + limit pushed to DataSourceExec +-- Without --sorted: TopK sort over all data +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown/q3.sql b/benchmarks/queries/sort_pushdown/q3.sql new file mode 100644 index 0000000000000..e11b48659a2a2 --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q3.sql @@ -0,0 +1,5 @@ +-- Sort elimination: wide projection (all columns) +-- Tests sort elimination benefit with larger row payload +SELECT * +FROM lineitem +ORDER BY l_orderkey diff --git a/benchmarks/queries/sort_pushdown/q4.sql b/benchmarks/queries/sort_pushdown/q4.sql new file mode 100644 index 0000000000000..99500c371991a --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q4.sql @@ -0,0 +1,5 @@ +-- Sort elimination + limit: wide projection +SELECT * +FROM lineitem +ORDER BY l_orderkey +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown/q5.sql b/benchmarks/queries/sort_pushdown/q5.sql new file mode 100644 index 0000000000000..60ad636ad3c9c --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q5.sql @@ -0,0 +1,7 @@ +-- Reverse scan: ORDER BY DESC LIMIT (narrow projection) +-- With --sorted: reverse_row_groups=true + TopK + stats init + cumulative prune +-- Without --sorted: full TopK sort over all data +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown/q6.sql b/benchmarks/queries/sort_pushdown/q6.sql new file mode 100644 index 0000000000000..d36a35a1e5a0d --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q6.sql @@ -0,0 +1,5 @@ +-- Reverse scan: ORDER BY DESC LIMIT larger fetch (narrow projection) +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown/q7.sql b/benchmarks/queries/sort_pushdown/q7.sql new file mode 100644 index 0000000000000..3e8856822d83d --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q7.sql @@ -0,0 +1,5 @@ +-- Reverse scan: wide projection + DESC LIMIT +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git 
a/benchmarks/queries/sort_pushdown/q8.sql b/benchmarks/queries/sort_pushdown/q8.sql new file mode 100644 index 0000000000000..95ba89fdd5089 --- /dev/null +++ b/benchmarks/queries/sort_pushdown/q8.sql @@ -0,0 +1,5 @@ +-- Reverse scan: wide projection + DESC LIMIT larger fetch +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact/q1.sql b/benchmarks/queries/sort_pushdown_inexact/q1.sql new file mode 100644 index 0000000000000..d772bc486a12b --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact/q1.sql @@ -0,0 +1,8 @@ +-- Inexact path: TopK + DESC LIMIT on ASC-declared file. +-- With RG reorder, the first RG read contains the highest max value, +-- so TopK's threshold tightens quickly and subsequent RGs get filtered +-- efficiently via dynamic filter pushdown. +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact/q2.sql b/benchmarks/queries/sort_pushdown_inexact/q2.sql new file mode 100644 index 0000000000000..6e2bef44fc37e --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact/q2.sql @@ -0,0 +1,7 @@ +-- Inexact path: TopK + DESC LIMIT with larger fetch (1000). +-- Larger LIMIT means more row_replacements; RG reorder reduces the +-- total replacement count by tightening the threshold faster. +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact/q3.sql b/benchmarks/queries/sort_pushdown_inexact/q3.sql new file mode 100644 index 0000000000000..d858ec79a67c9 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact/q3.sql @@ -0,0 +1,8 @@ +-- Inexact path: wide projection (all columns) + DESC LIMIT. +-- Shows the row-level filter benefit: with a tight threshold from the +-- first RG, subsequent RGs skip decoding non-sort columns for filtered +-- rows — bigger wins for wide tables. +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact/q4.sql b/benchmarks/queries/sort_pushdown_inexact/q4.sql new file mode 100644 index 0000000000000..bd2efc5d3b992 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact/q4.sql @@ -0,0 +1,7 @@ +-- Inexact path: wide projection + DESC LIMIT with larger fetch. +-- Combines wide-row row-level filter benefit with larger LIMIT to +-- demonstrate cumulative gains from RG reorder. +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact_overlap/q1.sql b/benchmarks/queries/sort_pushdown_inexact_overlap/q1.sql new file mode 100644 index 0000000000000..0e978bddbed03 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_overlap/q1.sql @@ -0,0 +1,7 @@ +-- Overlapping RGs: TopK + DESC LIMIT on file with partially overlapping +-- row groups (simulates streaming data with network jitter). +-- RG reorder places highest-max RG first for fastest threshold convergence. +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact_overlap/q2.sql b/benchmarks/queries/sort_pushdown_inexact_overlap/q2.sql new file mode 100644 index 0000000000000..34d0a910cbf3a --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_overlap/q2.sql @@ -0,0 +1,5 @@ +-- Overlapping RGs: DESC LIMIT with larger fetch. 
+SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact_overlap/q3.sql b/benchmarks/queries/sort_pushdown_inexact_overlap/q3.sql new file mode 100644 index 0000000000000..08b30b24d3dd1 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_overlap/q3.sql @@ -0,0 +1,6 @@ +-- Overlapping RGs: wide projection + DESC LIMIT. +-- Row-level filter benefit: tight threshold skips decoding non-sort columns. +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact_overlap/q4.sql b/benchmarks/queries/sort_pushdown_inexact_overlap/q4.sql new file mode 100644 index 0000000000000..4c091424f901c --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_overlap/q4.sql @@ -0,0 +1,5 @@ +-- Overlapping RGs: wide projection + DESC LIMIT larger fetch. +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact_unsorted/q1.sql b/benchmarks/queries/sort_pushdown_inexact_unsorted/q1.sql new file mode 100644 index 0000000000000..06748b72a98a3 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_unsorted/q1.sql @@ -0,0 +1,7 @@ +-- Unsupported path: TopK + ASC LIMIT on file without declared ordering. +-- Tests RG reorder benefit when no WITH ORDER is declared — the +-- Unsupported path in try_pushdown_sort triggers RG reorder. +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact_unsorted/q2.sql b/benchmarks/queries/sort_pushdown_inexact_unsorted/q2.sql new file mode 100644 index 0000000000000..384e4647eb0d9 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_unsorted/q2.sql @@ -0,0 +1,5 @@ +-- Unsupported path: TopK + ASC LIMIT with larger fetch. +SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact_unsorted/q3.sql b/benchmarks/queries/sort_pushdown_inexact_unsorted/q3.sql new file mode 100644 index 0000000000000..d48a2d969c468 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_unsorted/q3.sql @@ -0,0 +1,6 @@ +-- Unsupported path: wide projection + ASC LIMIT. +-- Shows row-level filter benefit when RG reorder tightens TopK threshold. +SELECT * +FROM lineitem +ORDER BY l_orderkey +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact_unsorted/q4.sql b/benchmarks/queries/sort_pushdown_inexact_unsorted/q4.sql new file mode 100644 index 0000000000000..d12d48f43a626 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_unsorted/q4.sql @@ -0,0 +1,5 @@ +-- Unsupported path: wide projection + ASC LIMIT with larger fetch. +SELECT * +FROM lineitem +ORDER BY l_orderkey +LIMIT 1000 diff --git a/benchmarks/queries/sort_pushdown_inexact_unsorted/q5.sql b/benchmarks/queries/sort_pushdown_inexact_unsorted/q5.sql new file mode 100644 index 0000000000000..ab1dddab408f3 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_unsorted/q5.sql @@ -0,0 +1,5 @@ +-- Unsupported path: DESC LIMIT (no declared ordering = no reverse scan). 
+SELECT l_orderkey, l_partkey, l_suppkey +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/queries/sort_pushdown_inexact_unsorted/q6.sql b/benchmarks/queries/sort_pushdown_inexact_unsorted/q6.sql new file mode 100644 index 0000000000000..8366e96969195 --- /dev/null +++ b/benchmarks/queries/sort_pushdown_inexact_unsorted/q6.sql @@ -0,0 +1,5 @@ +-- Unsupported path: wide projection + DESC LIMIT. +SELECT * +FROM lineitem +ORDER BY l_orderkey DESC +LIMIT 100 diff --git a/benchmarks/sql_benchmarks/README.md b/benchmarks/sql_benchmarks/README.md new file mode 100644 index 0000000000000..e8899641c024b --- /dev/null +++ b/benchmarks/sql_benchmarks/README.md @@ -0,0 +1,355 @@
+
+# SQL Benchmarks
+
+This directory contains a collection of benchmarks, each driven by a simple '.benchmark' text file and SQL queries,
+that exercise the DataFusion execution engine against a variety of benchmark suites. The SQL benchmark framework
+is intentionally simple so that benchmarks and queries can be added or modified without touching the core
+engine or requiring recompilation.
+
+The SQL benchmarks are organized in sub-directories that correspond to benchmark suites commonly used
+in the community:
+
+| Benchmark Suite       | Description                                                        |
+|-----------------------|--------------------------------------------------------------------|
+| `clickbench`          | ClickBench benchmark                                               |
+| `clickbench extended` | 12 additional, more complex queries against the ClickBench dataset |
+| `clickbench_sorted`   | ClickBench benchmark using a pre-sorted hits file                  |
+| `h2o`                 | The `h2o` benchmark                                                |
+| `hj`                  | Hash join benchmark                                                |
+| `imdb`                | IMDb benchmark                                                     |
+| `nlj`                 | Nested-loop join benchmark                                         |
+| `smj`                 | Sort-merge join benchmark                                          |
+| `sort tpch`           | Sorting benchmarks against the TPC-H lineitem table                |
+| `taxi`                | NYC taxi dataset benchmark                                         |
+| `tpcds`               | TPC-DS queries                                                     |
+| `tpch`                | TPC-H queries                                                      |
+
+# Running Benchmarks
+
+The easiest way to run a benchmark is to use the `bench.sh` shell script (one level up from this document),
+as it takes care of configuring any required environment variables and can populate any required data files.
+However, it is also possible to run a SQL benchmark directly using the `cargo bench` command. For example:
+
+```shell
+BENCH_NAME=tpch cargo bench --bench sql
+```
+
+# Benchmark configuration
+
+SQL benchmarks are configured via environment variables. Cargo's bench command and
+[criterion](https://github.com/criterion-rs/criterion.rs) (the underlying benchmark framework) have an unfortunate
+limitation in that custom command-line arguments cannot be passed to a benchmark, so arguments are passed in
+via environment variables instead.
+
+The SQL benchmarking tool uses the following environment variables:
+
+| Environment Variable  | Description |
+|-----------------------|-------------|
+| BENCH_NAME            | The name of the benchmark suite to run, for example 'imdb'. This should correspond to a directory name in the `sql_benchmarks` directory. |
+| BENCH_SUBGROUP        | The subgroup within the benchmark suite to run, for example 'window' to run the window subgroup of the h2o benchmark. |
+| BENCH_QUERY           | A query number to run. |
+| BENCH_PERSIST_RESULTS | true/false to persist benchmark results. Results will be persisted in CSV format, so be cognizant of the size of the results. |
+| BENCH_VALIDATE        | true/false to validate benchmark results against persisted results or `result_query` directives. If both `BENCH_PERSIST_RESULTS` and `BENCH_VALIDATE` are true, persist mode runs and validation is skipped. |
+| SIMULATE_LATENCY      | Simulate object store latency to mimic remote storage (e.g. S3). Adds random latency in the range 20-200ms to each object store operation. |
+| MEM_POOL_TYPE         | The memory pool type to use; one of "fair" or "greedy". |
+| MEMORY_LIMIT          | Memory limit (e.g. '100M', '1.5G'). If not specified, all pre-defined memory limits for the given query are run (if any exist); otherwise the query runs with no memory limit. |
+
+Example – Run the H2O window benchmarks on the 'small' sized CSV data files:
+
+```bash
+BENCH_NAME=h2o BENCH_SUBGROUP=window H2O_BENCH_SIZE=small H2O_FILE_TYPE=csv cargo bench --bench sql
+```
+
+Some benchmarks use custom environment variables as outlined below:
+
+| Name                         | Description                                                                                                                | Default value |
+|------------------------------|----------------------------------------------------------------------------------------------------------------------------|---------------|
+| BENCH_SIZE                   | Used in the tpch, sort-tpch and tpcds benchmarks. The size corresponds to the scale factor.                                 | `1`           |
+| TPCH_FILE_TYPE               | Used in the tpch benchmark to specify which file type to query against. The valid options are `csv`, `parquet` and `mem`.   | `parquet`     |
+| H2O_FILE_TYPE                | Used in the h2o benchmark to specify which file type to query against. The valid options are `csv` and `parquet`.           | `csv`         |
+| CLICKBENCH_TYPE              | The type of partitioning for the clickbench benchmark. Valid options are `single` and `partitioned`.                        | `single`      |
+| H2O_BENCH_SIZE               | Used in the h2o benchmark. The valid options are `small`, `medium` and `big`.                                               | `small`       |
+| PREFER_HASH_JOIN             | Controls DataFusion's config option `datafusion.optimizer.prefer_hash_join`.                                                | true          |
+| HASH_JOIN_BUFFERING_CAPACITY | Controls DataFusion's config option `datafusion.execution.hash_join_buffering_capacity`.                                    | 0             |
+| BENCH_SORTED                 | Used in the sort_tpch benchmark to indicate whether the lineitem table should be sorted.                                    | false         |
+| SORTED_BY                    | Used in the clickbench_sorted benchmark to indicate the column to sort by.                                                  | `EventTime`   |
+| SORTED_ORDER                 | Used in the clickbench_sorted benchmark to indicate the sort order of the column.                                           | `ASC`         |
+
+## How it works
+
+SQL benchmarks are run via cargo's bench command, using [criterion](https://docs.rs/criterion/latest/criterion/)
+to run, and gather statistics for, each SQL statement being benchmarked.
+
+Each individual benchmark is represented by a `.benchmark` file that contains a number of directives instructing
+the tool on how to load data, run initializations, run assertions, run the benchmark, optionally persist and
+validate results, and finally run any cleanup if required.
+
+Variables are supported in two forms:
+
+* string substitution based on environment variables (with default values if unset): \${ENV_VAR} and
+  \${ENV_VAR:-default}.
+* if / else based on whether an environment variable is true or not
+  (\${ENV_VAR:-default|true value|false value}). In this form only the value `true` (case-insensitive) selects the
+  true branch; any other set value selects the false branch. If ENV_VAR is unset, the value of `default` is used
+  to select the branch.
+
+Comments in files are supported with lines starting with # or --.
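+
+To make the two substitution forms concrete, here is a small illustrative sketch; the `USE_PARQUET` variable and
+the paths are hypothetical, not taken from any existing benchmark:
+
+```
+# Plain substitution with a default: resolves to data/tpch_sf10/load.sql
+# when BENCH_SIZE=10, and to data/tpch_sf1/load.sql when BENCH_SIZE is unset.
+load data/tpch_sf${BENCH_SIZE:-1}/load.sql
+
+# if / else substitution: the default is 'true', so with USE_PARQUET unset
+# (or set to true) the 'parquet' branch is chosen; any other value, e.g.
+# USE_PARQUET=false, chooses the 'csv' branch.
+run sql_benchmarks/example/queries/q01_${USE_PARQUET:-true|parquet|csv}.sql
+```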
+
+Many, if not most, of the benchmarks are set up using templates to reduce duplication across the .benchmark files.
+For example, here is one of the benchmark files for the h2o benchmark suite:
+
+```
+subgroup groupby
+
+template sql_benchmarks/h2o/h2o.benchmark.template
+QUERY_NUMBER=1
+QUERY_NUMBER_PADDED=01
+```
+
+The benchmark file above declares the subgroup it is part of, then uses the template directive to set two variables
+(`QUERY_NUMBER` and `QUERY_NUMBER_PADDED`) and point to a file containing further directives that are shared across
+the benchmark suite:
+
+```
+load sql_benchmarks/h2o/init/load_${BENCH_SUBGROUP:-groupby}_${BENCH_SIZE:-small}_${BENCH_FILE_TYPE:-csv}.sql
+
+name Q${QUERY_NUMBER_PADDED}
+group h2o
+
+run sql_benchmarks/h2o/queries/${BENCH_SUBGROUP:-groupby}/q${QUERY_NUMBER_PADDED}.sql
+
+result sql_benchmarks/h2o/results/${BENCH_SUBGROUP:-groupby}/${BENCH_SIZE:-small}/q${QUERY_NUMBER_PADDED}.csv
+```
+
+The template file above showcases the use of default values for variables: `${NAME:-default}`.
+
+# Directives
+
+<table>
+<tr><th>Directive</th><th>Description</th></tr>
+<tr>
+<td>name</td>
+<td>
+The name of the benchmark. This will be used as part of the display name used by criterion.
+
+Example:
+
+<pre>
+name Q${QUERY_NUMBER_PADDED}
+</pre>
+
+The `name` directive also makes the value available to benchmark-file replacements as `BENCH_NAME`. This is separate
+from the `BENCH_NAME` environment variable used to select which benchmark group to run.
+</td>
+</tr>
+<tr>
+<td>group</td>
+<td>
+The group name of the benchmark, used for grouping benchmarks together.
+
+Example:
+
+<pre>
+group imdb
+</pre>
+</td>
+</tr>
+<tr>
+<td>subgroup</td>
+<td>
+The subgroup name of the benchmark, used for filtering to a specific subgroup.
+
+Example:
+
+<pre>
+subgroup window
+</pre>
+</td>
+</tr>
+<tr>
+<td>load</td>
+<td>
+The load directive is called during initialization of the benchmark. If a path to a file is provided on the same
+line as the load directive, that path is parsed and any SQL statements in that file are executed during
+initialization. If no path is specified, the next line must be the SQL statement to execute.
+
+The load directive (including any following SQL statement) must be followed by a blank line.
+
+Example:
+
+<pre>
+load sql_benchmarks/h2o/init/load_${BENCH_SUBGROUP:-groupby}_${BENCH_SIZE:-small}_${BENCH_FILE_TYPE:-csv}.sql
+</pre>
+
+or
+
+<pre>
+load
+CREATE TABLE test AS (SELECT value as key FROM range(1000000) ORDER BY value);
+</pre>
+</td>
+</tr>
+<tr>
+<td>init</td>
+<td>
+The init directive is called after the load directive, prior to benchmark execution. If a path to a file is
+provided on the same line as the init directive, that path is parsed and any SQL statements in that file are
+executed during benchmark initialization. If no path is specified, the next line must be the SQL statement to
+execute.
+
+The init directive (including any following SQL statement) must be followed by a blank line.
+
+Example:
+
+<pre>
+init
+set datafusion.execution.parquet.binary_as_string = true;
+</pre>
+</td>
+</tr>
+<tr>
+<td>run</td>
+<td>
+The run directive is called during execution of the benchmark. If a path to a file is provided on the same line as
+the run directive, that path is parsed and any SQL statements in that file are executed during the benchmark run.
+If no path is specified, the next line must be the SQL statement to execute.
+
+Multiple statements are allowed within a single run directive; however, a benchmark file may contain only one run
+directive. When running with `BENCH_PERSIST_RESULTS` or `BENCH_VALIDATE`, only the last `SELECT` or `WITH`
+statement from that run directive will be used for comparison.
+
+The run directive (including any following SQL statement) must be followed by a blank line.
+
+Example:
+
+<pre>
+run sql_benchmarks/imdb/queries/${QUERY_NUMBER_PADDED}.sql
+</pre>
+</td>
+</tr>
+<tr>
+<td>cleanup</td>
+<td>
+The cleanup directive is called after all other directives and can be used to clean up after the benchmark,
+e.g. to drop tables. If a path to a file is provided on the same line as the cleanup directive, that path is
+parsed and any SQL statements in that file are executed during cleanup. If no path is specified, the next line
+must be the SQL statement to execute.
+
+The cleanup directive (including any following SQL statement) must be followed by a blank line.
+
+Example:
+
+<pre>
+cleanup
+DROP TABLE test;
+</pre>
+</td>
+</tr>
+<tr>
+<td>expect_plan</td>
+<td>
+The expect_plan directive checks the physical plan for the string provided on the same line. This can be used to
+validate that a particular join implementation was used.
+
+Example:
+
+<pre>
+expect_plan NestedLoopJoinExec
+</pre>
+</td>
+</tr>
+<tr>
+<td>assert</td>
+<td>
+The assert directive is run between the init and run directives and can be used to validate that the system state
+is correct prior to running the benchmark SQL. The format is
+
+<pre>
+assert II
+SELECT name, value = 3 FROM information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency');
+----
+datafusion.execution.planning_concurrency true
+datafusion.execution.target_partitions true
+</pre>
+
+The number of I's corresponds to the number of columns in the result. The expected results can be either tab
+delimited or pipe delimited.
+</td>
+</tr>
+<tr>
+<td>result_query</td>
+<td>
+The result_query directive is run during the verify phase and can be used to verify a different set of results
+than those produced by queries executed from the `run` directive. The format is the same as the `assert` directive
+above.
+
+Example:
+
+<pre>
+result_query III
+SELECT COUNT(DISTINCT id2), SUM(r4), COUNT(*) FROM answer;
+----
+123 345 45
+</pre>
+
+Note that the results of the run query are not automatically stored in a table in DataFusion. If you want to
+verify a result from queries executed by the `run` directive, those queries will have to save their results to a
+table directly using `CREATE TABLE AS (..)` or similar.
+</td>
+</tr>
+<tr>
+<td>result</td>
+<td>
+The result directive declares the expected result file used during verification. A path to a file is required on
+the same line as the result directive. The file is parsed only during verification and must be a pipe-delimited
+CSV file with a header row. During verification, these expected rows are compared with the rows produced by the
+last saved `SELECT` or `WITH` statement from the `run` directive.
+
+Example:
+
+<pre>
+result sql_benchmarks/imdb/results/${QUERY_NUMBER_PADDED}.csv
+</pre>
+</td>
+</tr>
+<tr>
+<td>template</td>
+<td>
+The template directive allows for the inclusion of another file in a benchmark file. A path to a file, which will
+be parsed as a benchmark file, is required on the same line as the template directive. Parameters can be passed to
+the template file using the format `KEY=value`, one per line after the template directive, followed by a blank
+line.
+
+Example:
+
+<pre>
+template sql_benchmarks/smj/smj.benchmark.template
+QUERY_NUMBER=1
+QUERY_NUMBER_PADDED=01
+</pre>
+</td>
+</tr>
+<tr>
+<td>include</td>
+<td>
+The include directive is similar to the template directive except that it does not support parameters.
+</td>
+</tr>
+<tr>
+<td>echo</td>
+<td>
+The echo directive allows for echoing a string to stdout during the execution of the benchmark and may be useful
+for debugging.
+
+Example:
+
+<pre>
+echo The value for batch size is ${BATCH_SIZE:-8192}
+</pre>
+</td>
+</tr>
+</table>
+
+# Extending an existing benchmark suite
+
+If you want to add a new query:
+
+* Create a new qXX.sql in the corresponding queries folder of the benchmark.
+* Add a new qXX.benchmark that references the appropriate template (clickbench.benchmark.template,
+  h2o.benchmark.template, etc.).
+* (Optional) Add a new entry to the suite's load script if the data set is different.
+* (Optional) Manually create a result CSV to be compared against benchmark results during verification.
+
+# Adding a new benchmark suite
+
+* Create a new directory named for the new benchmark suite.
+* Within it, create a `.benchmark` file for each individual benchmark.
+* Populate the benchmark with directives as described above. Use the other benchmarks as examples for standardization.
+* No Rust files need to be updated to run the new benchmark suite.
diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q01.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q01.benchmark new file mode 100644 index 0000000000000..d490927df326f --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q01.benchmark @@ -0,0 +1,34 @@ +name Q01 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order + from lineitem + where l_shipdate <= date '1998-12-01' - interval '90' day + group by l_returnflag, + l_linestatus + order by l_returnflag, + l_linestatus; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q01.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q02.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q02.benchmark new file mode 100644 index 0000000000000..6f365248b4998 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q02.benchmark @@ -0,0 +1,62 @@ +name Q02 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' +) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit 100; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q02.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q03.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q03.benchmark new file mode 100644 index 0000000000000..cb16b10c2bb5a --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q03.benchmark @@
-0,0 +1,41 @@ +name Q03 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit 10; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q03.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q04.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q04.benchmark new file mode 100644 index 0000000000000..f2e6f9a558416 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q04.benchmark @@ -0,0 +1,39 @@ +name Q04 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q04.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q05.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q05.benchmark new file mode 100644 index 0000000000000..9b5fbda63b4cb --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q05.benchmark @@ -0,0 +1,37 @@ +name Q05 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from customer, + orders, + lineitem, + supplier, + nation, + region +where c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year +group by n_name +order by revenue desc; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q05.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q06.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q06.benchmark new file mode 100644 index 0000000000000..761875cbf3558 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q06.benchmark @@ -0,0 +1,25 @@ +name Q06 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select sum(l_extendedprice * l_discount) as revenue +from lineitem +where l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 
0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; + + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q06.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q07.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q07.benchmark new file mode 100644 index 0000000000000..30c4c520de823 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q07.benchmark @@ -0,0 +1,57 @@ +name Q07 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q07.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q08.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q08.benchmark new file mode 100644 index 0000000000000..86caded4b7f1e --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q08.benchmark @@ -0,0 +1,55 @@ +name Q08 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q08.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q09.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q09.benchmark new file mode 100644 index 0000000000000..3302cf6f0ba81 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q09.benchmark @@ -0,0 +1,50 @@ +name Q09 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + 
extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q09.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q10.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q10.benchmark new file mode 100644 index 0000000000000..4ef08e3fd2074 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q10.benchmark @@ -0,0 +1,50 @@ +name Q10 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit 20; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q10.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q11.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q11.benchmark new file mode 100644 index 0000000000000..833799a39d756 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q11.benchmark @@ -0,0 +1,45 @@ +name Q11 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * (0.0001/${BENCH_SIZE:-1}) + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q11.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q12.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q12.benchmark new file mode 100644 index 0000000000000..37aee848c962b --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q12.benchmark @@ -0,0 +1,43 @@ +name Q12 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority 
<> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from lineitem + join + orders + on + l_orderkey = o_orderkey +where l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year +group by l_shipmode +order by l_shipmode; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q12.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q13.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q13.benchmark new file mode 100644 index 0000000000000..dfb09853d0987 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q13.benchmark @@ -0,0 +1,38 @@ +name Q13 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q13.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q14.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q14.benchmark new file mode 100644 index 0000000000000..b48d95043fdcb --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q14.benchmark @@ -0,0 +1,28 @@ +name Q14 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part +where l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q14.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q15.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q15.benchmark new file mode 100644 index 0000000000000..0f50fc499d0b4 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q15.benchmark @@ -0,0 +1,49 @@ +name Q15 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey; +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey; +drop view revenue0; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q15.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql 
diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q16.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q16.benchmark new file mode 100644 index 0000000000000..3fa6c68e29985 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q16.benchmark @@ -0,0 +1,48 @@ +name Q16 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' +) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q16.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q17.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q17.benchmark new file mode 100644 index 0000000000000..a31c837d1e164 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q17.benchmark @@ -0,0 +1,35 @@ +name Q17 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey +); + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q17.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q18.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q18.benchmark new file mode 100644 index 0000000000000..149b0efd01c99 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q18.benchmark @@ -0,0 +1,51 @@ +name Q18 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q18.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q19.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q19.benchmark new file mode 100644 index 0000000000000..f93ad6cb73143 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q19.benchmark @@ -0,0 +1,53 @@ +name Q19 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from 
lineitem; +---- +true + +run +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q19.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q20.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q20.benchmark new file mode 100644 index 0000000000000..123386055b1ba --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q20.benchmark @@ -0,0 +1,55 @@ +name Q20 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q20.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/benchmarks/q21.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q21.benchmark new file mode 100644 index 0000000000000..24d754a4cbd15 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q21.benchmark @@ -0,0 +1,58 @@ +name Q21 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q21.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git 
a/benchmarks/sql_benchmarks/tpch/benchmarks/q22.benchmark b/benchmarks/sql_benchmarks/tpch/benchmarks/q22.benchmark new file mode 100644 index 0000000000000..7ef6a78496c32 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/benchmarks/q22.benchmark @@ -0,0 +1,55 @@ +name Q22 +group tpch +subgroup sf${BENCH_SIZE:-1} + +init sql_benchmarks/tpch/init/set_config.sql + +load sql_benchmarks/tpch/init/load_${TPCH_FILE_TYPE:-parquet}.sql + +assert I +SELECT COUNT(*) > 0 from lineitem; +---- +true + +run +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; + +result sql_benchmarks/tpch/results/sf${BENCH_SIZE:-1}/q22.csv + +cleanup sql_benchmarks/tpch/init/cleanup.sql diff --git a/benchmarks/sql_benchmarks/tpch/init/cleanup.sql b/benchmarks/sql_benchmarks/tpch/init/cleanup.sql new file mode 100644 index 0000000000000..c8fb66a6a57e8 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/init/cleanup.sql @@ -0,0 +1,15 @@ +DROP TABLE IF EXISTS nation; + +DROP TABLE IF EXISTS region; + +DROP TABLE IF EXISTS supplier; + +DROP TABLE IF EXISTS customer; + +DROP TABLE IF EXISTS part; + +DROP TABLE IF EXISTS partsupp; + +DROP TABLE IF EXISTS orders; + +DROP TABLE IF EXISTS lineitem; diff --git a/benchmarks/sql_benchmarks/tpch/init/load_csv.sql b/benchmarks/sql_benchmarks/tpch/init/load_csv.sql new file mode 100644 index 0000000000000..f9a9b2e988e24 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/init/load_csv.sql @@ -0,0 +1,99 @@ +CREATE EXTERNAL TABLE nation +( + n_nationkey INT, + n_name CHAR(25), + n_regionkey INT, + n_comment VARCHAR(152), + PRIMARY KEY (n_nationkey) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/nation/nation.1.csv'; + +CREATE EXTERNAL TABLE region +( + r_regionkey INT, + r_name CHAR(25), + r_comment VARCHAR(152), + PRIMARY KEY (r_regionkey) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/region/region.1.csv'; + +CREATE EXTERNAL TABLE supplier +( + s_suppkey INT, + s_name CHAR(25), + s_address VARCHAR(40), + s_nationkey INT, + s_phone CHAR(15), + s_acctbal DECIMAL(15, 2), + s_comment VARCHAR(101), + PRIMARY KEY (s_suppkey) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/supplier/supplier.1.csv'; + +CREATE EXTERNAL TABLE customer +( + c_custkey INT, + c_name VARCHAR(25), + c_address VARCHAR(40), + c_nationkey INT, + c_phone CHAR(15), + c_acctbal DECIMAL(15, 2), + c_mktsegment CHAR(10), + c_comment VARCHAR(117), + PRIMARY KEY (c_custkey) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/customer/customer.1.csv'; + +CREATE EXTERNAL TABLE part +( + p_partkey INT, + p_name VARCHAR(55), + p_mfgr CHAR(25), + p_brand CHAR(10), + p_type VARCHAR(25), + p_size INT, + p_container CHAR(10), + p_retailprice DECIMAL(15, 2), + p_comment VARCHAR(23), + PRIMARY KEY (p_partkey) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/part/part.1.csv'; + +CREATE EXTERNAL TABLE partsupp +( + ps_partkey INT, + ps_suppkey INT, + ps_availqty INT, + ps_supplycost DECIMAL(15, 2), + ps_comment VARCHAR(199), + PRIMARY KEY (ps_partkey) +) 
STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/partsupp/partsupp.1.csv'; + +CREATE EXTERNAL TABLE orders +( + o_orderkey INT, + o_custkey INT, + o_orderstatus CHAR(1), + o_totalprice DECIMAL(15, 2), + o_orderdate DATE, + o_orderpriority CHAR(15), + o_clerk CHAR(15), + o_shippriority INT, + o_comment VARCHAR(79), + PRIMARY KEY (o_orderkey) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/orders/orders.1.csv'; + +CREATE EXTERNAL TABLE lineitem +( + l_orderkey INT, + l_partkey INT, + l_suppkey INT, + l_linenumber INT, + l_quantity DECIMAL(15, 2), + l_extendedprice DECIMAL(15, 2), + l_discount DECIMAL(15, 2), + l_tax DECIMAL(15, 2), + l_returnflag CHAR(1), + l_linestatus CHAR(1), + l_shipdate DATE, + l_commitdate DATE, + l_receiptdate DATE, + l_shipinstruct CHAR(25), + l_shipmode CHAR(10), + l_comment VARCHAR(44) +) STORED AS CSV LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/csv/lineitem/lineitem.1.csv'; \ No newline at end of file diff --git a/benchmarks/sql_benchmarks/tpch/init/load_mem.sql b/benchmarks/sql_benchmarks/tpch/init/load_mem.sql new file mode 100644 index 0000000000000..57d12c22f0c52 --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/init/load_mem.sql @@ -0,0 +1,31 @@ +CREATE EXTERNAL TABLE nation_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/nation/nation.1.parquet'; + +CREATE EXTERNAL TABLE region_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/region/region.1.parquet'; + +CREATE EXTERNAL TABLE supplier_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/supplier/supplier.1.parquet'; + +CREATE EXTERNAL TABLE customer_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/customer/customer.1.parquet'; + +CREATE EXTERNAL TABLE part_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/part/part.1.parquet'; + +CREATE EXTERNAL TABLE partsupp_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/partsupp/partsupp.1.parquet'; + +CREATE EXTERNAL TABLE orders_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/orders/orders.1.parquet'; + +CREATE EXTERNAL TABLE lineitem_raw STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/lineitem/lineitem.1.parquet'; + +CREATE TABLE nation as SELECT * FROM nation_raw; + +CREATE TABLE region as SELECT * FROM region_raw; + +CREATE TABLE supplier as SELECT * FROM supplier_raw; + +CREATE TABLE customer as SELECT * FROM customer_raw; + +CREATE TABLE part as SELECT * FROM part_raw; + +CREATE TABLE partsupp as SELECT * FROM partsupp_raw; + +CREATE TABLE orders as SELECT * FROM orders_raw; + +CREATE TABLE lineitem as SELECT * FROM lineitem_raw; diff --git a/benchmarks/sql_benchmarks/tpch/init/load_parquet.sql b/benchmarks/sql_benchmarks/tpch/init/load_parquet.sql new file mode 100644 index 0000000000000..172a03d82a2cf --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/init/load_parquet.sql @@ -0,0 +1,15 @@ +CREATE EXTERNAL TABLE nation STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/nation/nation.1.parquet'; + +CREATE EXTERNAL TABLE region STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/region/region.1.parquet'; + +CREATE EXTERNAL TABLE supplier STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/supplier/supplier.1.parquet'; + +CREATE EXTERNAL TABLE customer STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/customer/customer.1.parquet'; + +CREATE EXTERNAL TABLE part STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/part/part.1.parquet'; + +CREATE EXTERNAL TABLE partsupp STORED AS PARQUET LOCATION 
'data/tpch_sf${BENCH_SIZE:-1}/partsupp/partsupp.1.parquet'; + +CREATE EXTERNAL TABLE orders STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/orders/orders.1.parquet'; + +CREATE EXTERNAL TABLE lineitem STORED AS PARQUET LOCATION 'data/tpch_sf${BENCH_SIZE:-1}/lineitem/lineitem.1.parquet'; \ No newline at end of file diff --git a/benchmarks/sql_benchmarks/tpch/init/set_config.sql b/benchmarks/sql_benchmarks/tpch/init/set_config.sql new file mode 100644 index 0000000000000..00457e2bca1ef --- /dev/null +++ b/benchmarks/sql_benchmarks/tpch/init/set_config.sql @@ -0,0 +1,3 @@ +set datafusion.optimizer.prefer_hash_join=${PREFER_HASH_JOIN:-true}; + +set datafusion.execution.hash_join_buffering_capacity=${HASH_JOIN_BUFFERING_CAPACITY:-0}; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 816cae0e38555..3b1f54291e75c 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -18,7 +18,7 @@ //! DataFusion benchmark runner use datafusion::error::Result; -use structopt::StructOpt; +use clap::{Parser, Subcommand}; #[cfg(all(feature = "snmalloc", feature = "mimalloc"))] compile_error!( @@ -34,11 +34,18 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; use datafusion_benchmarks::{ - cancellation, clickbench, h2o, hj, imdb, nlj, sort_tpch, tpch, + cancellation, clickbench, h2o, hj, imdb, nlj, smj, sort_pushdown, sort_tpch, tpcds, + tpch, }; -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] +#[derive(Debug, Parser)] +#[command(about = "benchmark command")] +struct Cli { + #[command(subcommand)] + command: Options, +} + +#[derive(Debug, Subcommand)] enum Options { Cancellation(cancellation::RunOpt), Clickbench(clickbench::RunOpt), @@ -46,9 +53,11 @@ enum Options { HJ(hj::RunOpt), Imdb(imdb::RunOpt), Nlj(nlj::RunOpt), + Smj(smj::RunOpt), + SortPushdown(sort_pushdown::RunOpt), SortTpch(sort_tpch::RunOpt), Tpch(tpch::RunOpt), - TpchConvert(tpch::ConvertOpt), + Tpcds(tpcds::RunOpt), } // Main benchmark runner entrypoint @@ -56,15 +65,18 @@ enum Options { pub async fn main() -> Result<()> { env_logger::init(); - match Options::from_args() { + let cli = Cli::parse(); + match cli.command { Options::Cancellation(opt) => opt.run().await, Options::Clickbench(opt) => opt.run().await, Options::H2o(opt) => opt.run().await, Options::HJ(opt) => opt.run().await, Options::Imdb(opt) => Box::pin(opt.run()).await, Options::Nlj(opt) => opt.run().await, + Options::Smj(opt) => opt.run().await, + Options::SortPushdown(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, Options::Tpch(opt) => Box::pin(opt.run()).await, - Options::TpchConvert(opt) => opt.run().await, + Options::Tpcds(opt) => Box::pin(opt.run()).await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index 46b6cc9a80b24..ee604ec7365a1 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -17,13 +17,13 @@ //! 
external_aggr binary entrypoint +use clap::{Args, Parser, Subcommand}; use datafusion::execution::memory_pool::GreedyMemoryPool; use datafusion::execution::memory_pool::MemoryPool; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::sync::LazyLock; -use structopt::StructOpt; use arrow::record_batch::RecordBatch; use arrow::util::pretty; @@ -33,47 +33,53 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; use datafusion::execution::memory_pool::FairSpillPool; -use datafusion::execution::memory_pool::{human_readable_size, units}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; +use datafusion_common::{DEFAULT_PARQUET_EXTENSION, exec_err}; +use datafusion_common::{human_readable_size, units}; -#[derive(Debug, StructOpt)] -#[structopt( +#[derive(Debug, Parser)] +#[command( name = "datafusion-external-aggregation", about = "DataFusion external aggregation benchmark" )] +struct Cli { + #[command(subcommand)] + command: ExternalAggrOpt, +} + +#[derive(Debug, Subcommand)] enum ExternalAggrOpt { Benchmark(ExternalAggrConfig), } -#[derive(Debug, StructOpt)] +#[derive(Debug, Args)] struct ExternalAggrConfig { /// Query number. If not specified, runs all queries - #[structopt(short, long)] + #[arg(short, long)] query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to data files (lineitem). Only parquet format is supported - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Path to JSON benchmark result to be compare using `compare.py` - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, } @@ -338,7 +344,8 @@ impl ExternalAggrConfig { pub async fn main() -> Result<()> { env_logger::init(); - match ExternalAggrOpt::from_args() { + let cli = Cli::parse(); + match cli.command { ExternalAggrOpt::Benchmark(opt) => opt.run().await?, } diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index 5ce99928df662..e86735f87b8f1 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -17,9 +17,9 @@ //! 
IMDB binary entrypoint +use clap::{Parser, Subcommand}; use datafusion::error::Result; use datafusion_benchmarks::imdb; -use structopt::StructOpt; #[cfg(all(feature = "snmalloc", feature = "mimalloc"))] compile_error!( @@ -34,24 +34,30 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] -enum BenchmarkSubCommandOpt { - #[structopt(name = "datafusion")] - DataFusionBenchmark(imdb::RunOpt), +#[derive(Debug, Parser)] +#[command(name = "IMDB", about = "IMDB Dataset Processing.")] +struct Cli { + #[command(subcommand)] + command: ImdbOpt, } -#[derive(Debug, StructOpt)] -#[structopt(name = "IMDB", about = "IMDB Dataset Processing.")] +#[derive(Debug, Subcommand)] enum ImdbOpt { + #[command(subcommand)] Benchmark(BenchmarkSubCommandOpt), Convert(imdb::ConvertOpt), } +#[derive(Debug, Subcommand)] +enum BenchmarkSubCommandOpt { + #[command(name = "datafusion")] + DataFusionBenchmark(imdb::RunOpt), +} + #[tokio::main] pub async fn main() -> Result<()> { env_logger::init(); - match ImdbOpt::from_args() { + match Cli::parse().command { ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { Box::pin(opt.run()).await } diff --git a/benchmarks/src/bin/mem_profile.rs b/benchmarks/src/bin/mem_profile.rs index 16fc3871bec86..41a0baecbba86 100644 --- a/benchmarks/src/bin/mem_profile.rs +++ b/benchmarks/src/bin/mem_profile.rs @@ -16,6 +16,7 @@ // under the License. //! mem_profile binary entrypoint +use clap::{Parser, Subcommand}; use datafusion::error::Result; use std::{ env, @@ -23,7 +24,6 @@ use std::{ path::Path, process::{Command, Stdio}, }; -use structopt::StructOpt; use datafusion_benchmarks::{ clickbench, @@ -31,19 +31,19 @@ use datafusion_benchmarks::{ imdb, sort_tpch, tpch, }; -#[derive(Debug, StructOpt)] -#[structopt(name = "Memory Profiling Utility")] -struct MemProfileOpt { +#[derive(Debug, Parser)] +#[command(name = "Memory Profiling Utility")] +struct Cli { /// Cargo profile to use in dfbench (e.g. release, release-nonlto) - #[structopt(long, default_value = "release")] + #[arg(long, default_value = "release")] bench_profile: String, - #[structopt(subcommand)] + #[command(subcommand)] command: Options, } -#[derive(Debug, StructOpt)] -#[structopt(about = "Benchmark command")] +#[derive(Debug, Subcommand)] +#[command(about = "Benchmark command")] enum Options { Clickbench(clickbench::RunOpt), H2o(h2o::RunOpt), @@ -55,9 +55,9 @@ enum Options { #[tokio::main] pub async fn main() -> Result<()> { // 1. Parse args and check which benchmarks should be run - let mem_profile_opt = MemProfileOpt::from_args(); - let profile = mem_profile_opt.bench_profile; - let query_range = match mem_profile_opt.command { + let cli = Cli::parse(); + let profile = cli.bench_profile; + let query_range = match cli.command { Options::Clickbench(opt) => { let entries = std::fs::read_dir(&opt.queries_path)? 
.filter_map(Result::ok) @@ -199,21 +199,18 @@ fn run_query(args: &[String], results: &mut Vec<QueryResult>) -> Result<()> { // Look for lines that contain execution time / memory stats while let Some(line) = iter.next() { - if let Some((query, duration_ms)) = parse_query_time(line) { - if let Some(next_line) = iter.peek() { - if let Some((peak_rss, peak_commit, page_faults)) = - parse_vm_line(next_line) - { - results.push(QueryResult { - query, - duration_ms, - peak_rss, - peak_commit, - page_faults, - }); - break; - } - } + if let Some((query, duration_ms)) = parse_query_time(line) + && let Some(next_line) = iter.peek() + && let Some((peak_rss, peak_commit, page_faults)) = parse_vm_line(next_line) + { + results.push(QueryResult { + query, + duration_ms, + peak_rss, + peak_commit, + page_faults, + }); + break; } }
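The mem_profile hunk above flattens three nested `if let` blocks into a single let-chain, a pattern stabilized with the Rust 2024 edition. A minimal self-contained sketch of the same shape; the parse helpers here are hypothetical stand-ins for `parse_query_time` and `parse_vm_line`, not the benchmark's own functions:

```rust
// Sketch of the let-chain pattern used in the hunk above (Rust 2024 edition).
fn first_pair(lines: &[&str]) -> Option<(u64, u64)> {
    let mut iter = lines.iter().peekable();
    while let Some(line) = iter.next() {
        // All three bindings must succeed for the body to run, replacing
        // three levels of nesting with one flat condition chain.
        if let Ok(a) = line.parse::<u64>()
            && let Some(next) = iter.peek()
            && let Ok(b) = next.parse::<u64>()
        {
            return Some((a, b));
        }
    }
    None
}

fn main() {
    // "x" fails to parse, so the chain first succeeds at ("1", peek "2").
    assert_eq!(first_pair(&["x", "1", "2"]), Some((1, 2)));
}
```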
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs deleted file mode 100644 index ca2bb8e57c0ec..0000000000000 --- a/benchmarks/src/bin/tpch.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! tpch binary only entrypoint - -use datafusion::error::Result; -use datafusion_benchmarks::tpch; -use structopt::StructOpt; - -#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] -compile_error!( - "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" ); - -#[cfg(feature = "snmalloc")] -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[cfg(feature = "mimalloc")] -#[global_allocator] -static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; - -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] -enum BenchmarkSubCommandOpt { - #[structopt(name = "datafusion")] - DataFusionBenchmark(tpch::RunOpt), -} - -#[derive(Debug, StructOpt)] -#[structopt(name = "TPC-H", about = "TPC-H Benchmarks.")] -enum TpchOpt { - Benchmark(BenchmarkSubCommandOpt), - Convert(tpch::ConvertOpt), -} - -/// 'tpch' entry point, with tortured command line arguments. Please -/// use `dbbench` instead. -/// -/// Note: this is kept to be backwards compatible with the benchmark names prior to -/// -#[tokio::main] -async fn main() -> Result<()> { - env_logger::init(); - match TpchOpt::from_args() { - TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - Box::pin(opt.run()).await - } - TpchOpt::Convert(opt) => opt.run().await, - } -} diff --git a/benchmarks/src/cancellation.rs b/benchmarks/src/cancellation.rs index fcf03fbc54550..d3da1b0e83623 100644 --- a/benchmarks/src/cancellation.rs +++ b/benchmarks/src/cancellation.rs @@ -24,24 +24,24 @@ use crate::util::{BenchmarkRun, CommonOpt}; use arrow::array::Array; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; +use clap::Args; use datafusion::common::{Result, ScalarValue}; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; -use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::execution::TaskContext; -use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::prelude::*; use datafusion_common::instant::Instant; use futures::TryStreamExt; use object_store::ObjectStore; -use parquet::arrow::async_writer::ParquetObjectWriter; use parquet::arrow::AsyncArrowWriter; +use parquet::arrow::async_writer::ParquetObjectWriter; +use rand::Rng; use rand::distr::Alphanumeric; use rand::rngs::ThreadRng; -use rand::Rng; -use structopt::StructOpt; use tokio::runtime::Runtime; use tokio_util::sync::CancellationToken; @@ -57,31 +57,31 @@ use tokio_util::sync::CancellationToken; /// The query is an anonymized version of a real-world query, and the /// test starts the query then cancels it and reports how long it takes /// for the runtime to fully exit. -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to folder where data will be generated - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option<PathBuf>, /// Number of files to generate - #[structopt(long = "num-files", default_value = "7")] + #[arg(long = "num-files", default_value = "7")] num_files: usize, /// Number of rows per file to generate - #[structopt(long = "num-rows-per-file", default_value = "5000000")] + #[arg(long = "num-rows-per-file", default_value = "5000000")] num_rows_per_file: usize, /// How long to wait, in milliseconds, before attempting to cancel - #[structopt(long = "wait-time", default_value = "100")] + #[arg(long = "wait-time", default_value = "100")] wait_time: u64, }
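For intuition, the measurement described in the doc comment above can be sketched as follows. This is not the benchmark's actual code: `run_query_until_cancelled` is a hypothetical stand-in for the spawned query; only the `CancellationToken` wiring mirrors what the file imports:

```rust
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;

/// Hypothetical stand-in for the long-running query under test.
async fn run_query_until_cancelled() {
    tokio::time::sleep(Duration::from_secs(3600)).await;
}

async fn measure_cancellation(wait_time_ms: u64) -> Duration {
    let token = CancellationToken::new();
    let child = token.child_token();
    let handle = tokio::spawn(async move {
        tokio::select! {
            // Either the token fires or the query (never) completes.
            _ = child.cancelled() => {}
            _ = run_query_until_cancelled() => {}
        }
    });
    // Let the query get going before attempting to cancel it.
    tokio::time::sleep(Duration::from_millis(wait_time_ms)).await;
    let start = Instant::now();
    token.cancel();
    let _ = handle.await; // time until the task has fully exited
    start.elapsed()
}

#[tokio::main]
async fn main() {
    println!("shutdown took {:?}", measure_cancellation(100).await);
}
```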
diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index a550503390c54..70aaeb7d2d192 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -19,7 +19,8 @@ use std::fs; use std::io::ErrorKind; use std::path::{Path, PathBuf}; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; +use clap::Args; use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{ error::{DataFusionError, Result}, @@ -27,7 +28,16 @@ use datafusion::{ }; use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; -use structopt::StructOpt; + +/// SQL to create the hits view with proper EventDate casting. +/// +/// ClickBench stores EventDate as UInt16 (days since 1970-01-01) for +/// storage efficiency (2 bytes vs 4-8 bytes for date types). +/// This view transforms it to SQL DATE type for query compatibility. +const HITS_VIEW_DDL: &str = r#"CREATE VIEW hits AS +SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" +FROM hits_raw"#; /// Driver program to run the ClickBench benchmark /// @@ -37,11 +47,11 @@ use structopt::StructOpt; /// /// [1]: https://github.com/ClickHouse/ClickBench /// [2]: https://github.com/ClickHouse/ClickBench/tree/main/datafusion -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number (between 0 and 42). If not specified, runs all queries - #[structopt(short, long)] + #[arg(short, long)] pub query: Option<usize>, /// If specified, enables Parquet Filter Pushdown. 
@@ -49,35 +59,54 @@ pub struct RunOpt { /// Specifically, it enables: /// * `pushdown_filters = true` /// * `reorder_filters = true` - #[structopt(long = "pushdown")] + #[arg(long = "pushdown")] pushdown: bool, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to hits.parquet (single file) or `hits_partitioned` /// (partitioned, 100 files) - #[structopt( - parse(from_os_str), - short = "p", + #[arg( + short = 'p', long = "path", default_value = "benchmarks/data/hits.parquet" )] path: PathBuf, /// Path to queries directory - #[structopt( - parse(from_os_str), - short = "r", + #[arg( + short = 'r', long = "queries-path", default_value = "benchmarks/queries/clickbench/queries" )] pub queries_path: PathBuf, /// If present, write results json here - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option<PathBuf>, + + /// Column name that the data is sorted by (e.g., "EventTime") + /// If specified, DataFusion will be informed that the data has this sort order + /// using CREATE EXTERNAL TABLE with a WITH ORDER clause. + /// + /// Recommended to use with: -c datafusion.optimizer.prefer_existing_sort=true + /// This allows DataFusion to optimize away redundant sorts while maintaining + /// multi-core parallelism for other operations. + #[arg(long = "sorted-by")] + sorted_by: Option<String>, + + /// Sort order: ASC or DESC (default: ASC) + #[arg(long = "sort-order", default_value = "ASC")] + sort_order: String, + + /// Configuration options in the format key=value + /// Can be specified multiple times. + /// + /// Example: -c datafusion.optimizer.prefer_existing_sort=true + #[arg(short = 'c', long = "config")] + config_options: Vec<String>, } /// Get the SQL file path @@ -125,6 +154,39 @@ impl RunOpt { // configure parquet options let mut config = self.common.config()?; + + if self.sorted_by.is_some() { + println!("ℹ️ Data is registered with sort order"); + + let has_prefer_sort = self + .config_options + .iter() + .any(|opt| opt.contains("prefer_existing_sort=true")); + + if !has_prefer_sort { + println!( + "ℹ️ Consider using -c datafusion.optimizer.prefer_existing_sort=true" + ); + println!("ℹ️ to optimize queries while maintaining parallelism"); + } + } + + // Apply user-provided configuration options + for config_opt in &self.config_options { + let parts: Vec<&str> = config_opt.splitn(2, '=').collect(); + if parts.len() != 2 { + return Err(exec_datafusion_err!( + "Invalid config option format: '{}'. Expected 'key=value'", + config_opt + )); + } + let key = parts[0]; + let value = parts[1]; + + println!("Setting config: {key} = {value}"); + config = config.set_str(key, value); + } +
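A note on the option parsing just above: `splitn(2, '=')` splits on the first `=` only, so configuration values may themselves contain `=`. A standalone illustration of that behavior; `parse_kv` is a hypothetical helper for this example, not part of the PR:

```rust
// Only the first '=' separates key from value; later '=' stay in the value.
fn parse_kv(opt: &str) -> Option<(&str, &str)> {
    let mut parts = opt.splitn(2, '=');
    match (parts.next(), parts.next()) {
        (Some(key), Some(value)) if !key.is_empty() => Some((key, value)),
        _ => None,
    }
}

fn main() {
    assert_eq!(
        parse_kv("datafusion.optimizer.prefer_existing_sort=true"),
        Some(("datafusion.optimizer.prefer_existing_sort", "true"))
    );
    // a value containing '=' survives intact
    assert_eq!(parse_kv("k=a=b"), Some(("k", "a=b")));
    // no separator at all is rejected
    assert_eq!(parse_kv("no-equals"), None);
}
```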
{ let parquet_options = &mut config.options_mut().execution.parquet; // The hits_partitioned dataset specifies string columns @@ -136,10 +198,18 @@ impl RunOpt { parquet_options.pushdown_filters = true; parquet_options.reorder_filters = true; } + + if self.sorted_by.is_some() { + // We want to compare the dynamic top-k optimization when data is sorted, + // so we assume filter pushdown is also enabled in this case. + parquet_options.pushdown_filters = true; + parquet_options.reorder_filters = true; + } } - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + self.register_hits(&ctx).await?; let mut benchmark_run = BenchmarkRun::new(); @@ -214,17 +284,68 @@ impl RunOpt { } /// Registers the `hits.parquet` as a table named `hits` + /// If sorted_by is specified, uses CREATE EXTERNAL TABLE with WITH ORDER async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { - let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); - ctx.register_parquet("hits", path, options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'hits' as {path}"), - Box::new(e), - ) - }) + + // If sorted_by is specified, use CREATE EXTERNAL TABLE with WITH ORDER + if let Some(ref sort_column) = self.sorted_by { + println!( + "Registering table with sort order: {} {}", + sort_column, self.sort_order + ); + + // Wrap the column name in double quotes unless it already contains them + let escaped_column = if sort_column.contains('"') { + sort_column.clone() + } else { + format!("\"{sort_column}\"") + }; + + // Build CREATE EXTERNAL TABLE DDL with WITH ORDER clause + // Schema will be automatically inferred from the Parquet file + let create_table_sql = format!( + "CREATE EXTERNAL TABLE hits_raw \ + STORED AS PARQUET \ + LOCATION '{}' \ + WITH ORDER ({} {})", + path, + escaped_column, + self.sort_order.to_uppercase() + ); + + println!("Executing: {create_table_sql}"); + + // Execute the CREATE EXTERNAL TABLE statement + ctx.sql(&create_table_sql).await?.collect().await?; + } else { + // Original registration without sort order + let options = Default::default(); + ctx.register_parquet("hits_raw", path, options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'hits_raw' as {path}"), + Box::new(e), + ) + })?; + } + + // Create the hits view with EventDate transformation + Self::create_hits_view(ctx).await + } + + /// Creates the hits view with EventDate transformation from UInt16 to DATE. + /// + /// ClickBench encodes EventDate as UInt16 days since epoch (1970-01-01). + async fn create_hits_view(ctx: &SessionContext) -> Result<()> { + ctx.sql(HITS_VIEW_DDL).await?.collect().await.map_err(|e| { + DataFusionError::Context( + "Creating 'hits' view with EventDate transformation".to_string(), + Box::new(e), + ) + })?; + Ok(()) } fn iterations(&self) -> usize { diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index be74252031194..8b6e04932cb39 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -20,31 +20,30 @@ //! - [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) //! 
- [Extended window function benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, print_memory_stats}; +use clap::Args; use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::{ - exec_datafusion_err, instant::Instant, internal_err, DataFusionError, TableReference, + DataFusionError, TableReference, exec_datafusion_err, instant::Instant, internal_err, }; use std::path::{Path, PathBuf}; -use structopt::StructOpt; /// Run the H2O benchmark -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { - #[structopt(short, long)] + #[arg(short, long)] pub query: Option<usize>, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to queries.sql (single file) /// default value is the groupby.sql file in the h2o benchmark - #[structopt( - parse(from_os_str), - short = "r", + #[arg( + short = 'r', long = "queries-path", default_value = "benchmarks/queries/h2o/groupby.sql" )] @@ -53,9 +52,8 @@ pub struct RunOpt { /// Path to data file (parquet or csv) /// Default value is the G1_1e7_1e7_100_0.csv file in the h2o benchmark /// This is the small csv file with 10^7 rows - #[structopt( - parse(from_os_str), - short = "p", + #[arg( + short = 'p', long = "path", default_value = "benchmarks/data/h2o/G1_1e7_1e7_100_0.csv" )] @@ -64,15 +62,15 @@ pub struct RunOpt { /// Path to data files (parquet or csv), using , to separate the paths /// Default value is the small files for join x table, small table, medium table, big table files in the h2o benchmark /// This is the small csv file case - #[structopt( - short = "join-paths", + #[arg( + short = 'j', long = "join-paths", default_value = "benchmarks/data/h2o/J1_1e7_NA_0.csv,benchmarks/data/h2o/J1_1e7_1e1_0.csv,benchmarks/data/h2o/J1_1e7_1e4_0.csv,benchmarks/data/h2o/J1_1e7_1e7_NA.csv" )] join_paths: String, /// If present, write results json here - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option<PathBuf>, } @@ -86,8 +84,8 @@ impl RunOpt { }; let config = self.common.config()?; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); // Register tables depending on which h2o benchmark is being run // (groupby/join/window) diff --git a/benchmarks/src/hj.rs b/benchmarks/src/hj.rs index 505b322745485..301fe0d599cd6 100644 --- a/benchmarks/src/hj.rs +++ b/benchmarks/src/hj.rs @@ -16,11 +16,12 @@ // under the License. use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use clap::Args; use datafusion::physical_plan::execute_stream; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::instant::Instant; -use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; -use structopt::StructOpt; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; +use std::path::PathBuf; use futures::StreamExt; @@ -32,139 +33,276 @@ use futures::StreamExt; /// It uses simple equality predicates to ensure a hash join is selected. /// Where we vary selectivity, we do so with additional cheap predicates that /// do not change the join key (so the physical operator remains HashJoin). -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { - /// Query number (between 1 and 12). If not specified, runs all queries - #[structopt(short, long)] + /// Query number. If not specified, runs all queries + #[arg(short, long)] query: Option<usize>, /// Common options (iterations, batch size, target_partitions, etc.) - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, + /// Path to TPC-H SF10 data + #[arg(short = 'p', long = "path")] + path: Option<PathBuf>, + /// If present, write results json here - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option<PathBuf>, + #[arg(short = 'o', long = "output")] + output_path: Option<PathBuf>, +} + +struct HashJoinQuery { + sql: &'static str, + density: f64, + prob_hit: f64, + build_size: &'static str, + probe_size: &'static str, +}
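The `density` and `prob_hit` fields just introduced drive how each case is labeled below: `density` is the fraction of the build-side key domain that is actually occupied, and `prob_hit` is the fraction of probe rows whose key finds a build-side match. As a rough model (a sketch assuming unique build-side keys, not benchmark code), the expected inner-join output is simply the probe row count scaled by the hit rate:

```rust
/// Rough model, assuming unique build-side keys: each probe row that
/// hits produces exactly one output row.
fn expected_inner_rows(probe_rows: u64, prob_hit: f64) -> u64 {
    (probe_rows as f64 * prob_hit) as u64
}

fn main() {
    // Q4 below: 60M probe rows at a 10% hit rate -> ~6M output rows.
    assert_eq!(expected_inner_rows(60_000_000, 0.1), 6_000_000);
}
```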
/// Inline SQL queries for Hash Join benchmarks -/// -/// Each query's comment includes: -/// - Left row count × Right row count -/// - Join predicate selectivity (approximate output fraction). -/// - Q11 and Q12 selectivity is relative to cartesian product while the others are -/// relative to probe side. -const HASH_QUERIES: &[&str] = &[ - // Q1: INNER 10 x 10K | LOW ~0.1% - // equality on key + cheap filter to downselect - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 9000, 1000) AS t1(value) - JOIN range(10000) AS t2 - ON t1.value = t2.value; - "#, - // Q2: INNER 10 x 10K | LOW ~0.1% - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 9000, 1000) AS t1 - JOIN range(10000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 5 = 0 - "#, - // Q3: INNER 10K x 10K | HIGH ~90% - r#" - SELECT t1.value, t2.value - FROM range(10000) AS t1 - JOIN range(10000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 10 <> 0 - "#, - // Q4: INNER 30 x 30K | LOW ~0.1% - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 29000, 1000) AS t1 - JOIN range(30000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 5 = 0 - "#, - // Q5: INNER 10 x 200K | VERY LOW ~0.005% (small to large) - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 9000, 1000) AS t1 - JOIN range(200000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 1000 = 0 - "#, - // Q6: INNER 200K x 10 | VERY LOW ~0.005% (large to small) - r#" - SELECT t1.value, t2.value - FROM range(200000) AS t1 - JOIN generate_series(0, 9000, 1000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 1000 = 0 - "#, - // Q7: RIGHT OUTER 10 x 200K | LOW ~0.1% - // Outer join still uses HashJoin for equi-keys; the extra filter reduces matches - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 9000, 1000) AS t1 - RIGHT JOIN range(200000) AS t2 - ON t1.value = t2.value - WHERE t2.value % 1000 = 0 - "#, - // Q8: LEFT OUTER 200K x 10 | LOW ~0.1% - r#" - SELECT t1.value AS l, t2.value AS r - FROM range(200000) AS t1 - LEFT JOIN generate_series(0, 9000, 1000) AS t2 - ON t1.value = t2.value - WHERE t1.value % 1000 = 0 - "#, - // Q9: FULL OUTER 30 x 30K | LOW ~0.1% - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON t1.value = t2.value - WHERE COALESCE(t1.value, t2.value) % 1000 = 0 - "#, - // Q10: FULL OUTER 30 x 30K | HIGH ~90% - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - 
FULL JOIN range(30000) AS t2 - ON t1.value = t2.value - WHERE COALESCE(t1.value, t2.value) % 10 <> 0 - "#, - // Q11: INNER 30 x 30K | MEDIUM ~50% | cheap predicate on parity - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 29000, 1000) AS t1 - INNER JOIN range(30000) AS t2 - ON (t1.value % 2) = (t2.value % 2) - "#, - // Q12: FULL OUTER 30 x 30K | MEDIUM ~50% | expression key - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON (t1.value % 2) = (t2.value % 2) - "#, - // Q13: INNER 30 x 30K | LOW 0.1% | modulo with adding values - r#" - SELECT t1.value, t2.value - FROM generate_series(0, 29000, 1000) AS t1 - INNER JOIN range(30000) AS t2 - ON (t1.value = t2.value) AND ((t1.value + t2.value) % 10 < 1) - "#, - // Q14: FULL OUTER 30 x 30K | ALL ~100% | modulo - r#" - SELECT t1.value AS l, t2.value AS r - FROM generate_series(0, 29000, 1000) AS t1 - FULL JOIN range(30000) AS t2 - ON (t1.value = t2.value) AND ((t1.value + t2.value) % 10 = 0) - "#, +const HASH_QUERIES: &[HashJoinQuery] = &[ + // Q1: Very Small Build Side (Dense) + // Build Side: nation (25 rows) | Probe Side: customer (1.5M rows) + HashJoinQuery { + sql: r###"SELECT n_nationkey FROM nation JOIN customer ON c_nationkey = n_nationkey"###, + density: 1.0, + prob_hit: 1.0, + build_size: "25", + probe_size: "1.5M", + }, + // Q2: Very Small Build Side (Sparse, range < 1024) + // Build Side: nation (25 rows, range 961) | Probe Side: customer (1.5M rows) + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT c_nationkey * 40 as k + FROM customer + ) l + JOIN ( + SELECT n_nationkey * 40 as k FROM nation + ) s ON l.k = s.k"###, + density: 0.026, + prob_hit: 1.0, + build_size: "25", + probe_size: "1.5M", + }, + // Q3: 100% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT s_suppkey FROM supplier JOIN lineitem ON s_suppkey = l_suppkey"###, + density: 1.0, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q4: 100% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE WHEN l_suppkey % 10 = 0 THEN l_suppkey ELSE l_suppkey + 1000000 END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey as k FROM supplier + ) s ON l.k = s.k"###, + density: 1.0, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q5: 75% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 4 / 3 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 4 / 3 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.75, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q6: 75% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 4 / 3 + WHEN l_suppkey % 10 < 9 THEN (l_suppkey * 4 / 3 / 4) * 4 + 3 + ELSE l_suppkey * 4 / 3 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 4 / 3 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.75, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q7: 50% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 2 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 2 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.5, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q8: 50% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 2 + WHEN l_suppkey % 10 < 9 THEN 
l_suppkey * 2 + 1 + ELSE l_suppkey * 2 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 2 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.5, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q9: 20% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 5 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 5 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q10: 20% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 5 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 5 + 1 + ELSE l_suppkey * 5 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 5 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q11: 10% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 10 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 10 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.1, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q12: 10% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 10 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 10 + 1 + ELSE l_suppkey * 10 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 10 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.1, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q13: 1% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 100 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 100 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.01, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q14: 1% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 100 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 100 + 1 + ELSE l_suppkey * 100 + 11000000 -- oob + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 100 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.01, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q15: 20% Density, 10% Hit rate, 20% Duplicates in Build Side + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN ((l_suppkey % 80000) + 1) * 25 / 4 + ELSE ((l_suppkey % 80000) + 1) * 25 / 4 + 1 + END as k + FROM lineitem + ) l + JOIN ( + SELECT CASE + WHEN s_suppkey <= 80000 THEN (s_suppkey * 25) / 4 + ELSE ((s_suppkey - 80000) * 25) / 4 + END as k + FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 0.1, + build_size: "100K_(20%_dups)", + probe_size: "60M", + }, ]; impl RunOpt { @@ -186,17 +324,47 @@ impl RunOpt { }; let config = self.common.config()?; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + + if let Some(path) = &self.path { + for table in &["lineitem", "supplier", "nation", "customer"] { + let table_path = path.join(table); + if !table_path.exists() { + return exec_err!( + "TPC-H table {} not found at {:?}", + table, + table_path + ); + } + ctx.register_parquet( + *table, + 
table_path.to_str().unwrap(), + Default::default(), + ) + .await?; + } + } let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { let query_index = query_id - 1; - let sql = HASH_QUERIES[query_index]; + let query = &HASH_QUERIES[query_index]; + + let case_name = format!( + "Query {}_density={}_prob_hit={}_{}*{}", + query_id, + query.density, + query.prob_hit, + query.build_size, + query.probe_size + ); + benchmark_run.start_new_case(&case_name); - benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + let query_run = self + .benchmark_query(query.sql, &query_id.to_string(), &ctx) + .await; match query_run { Ok(query_results) => { for iter in query_results { diff --git a/benchmarks/src/imdb/convert.rs b/benchmarks/src/imdb/convert.rs index e7949aa715c23..aaed186da4905 100644 --- a/benchmarks/src/imdb/convert.rs +++ b/benchmarks/src/imdb/convert.rs @@ -20,31 +20,31 @@ use datafusion::logical_expr::select_expr::SelectExpr; use datafusion_common::instant::Instant; use std::path::PathBuf; +use clap::Args; use datafusion::error::Result; use datafusion::prelude::*; -use structopt::StructOpt; use datafusion::common::not_impl_err; -use super::get_imdb_table_schema; use super::IMDB_TABLES; +use super::get_imdb_table_schema; -#[derive(Debug, StructOpt)] +#[derive(Debug, Args)] pub struct ConvertOpt { /// Path to csv files - #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] + #[arg(required = true, short = 'i', long = "input")] input_path: PathBuf, /// Output path - #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] + #[arg(required = true, short = 'o', long = "output")] output_path: PathBuf, /// Output file format: `csv` or `parquet` - #[structopt(short = "f", long = "format")] + #[arg(short = 'f', long = "format")] file_format: String, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] + #[arg(short = 's', long = "batch-size", default_value = "8192")] batch_size: usize, } diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 11bd424ba6866..ca9710a920517 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -19,16 +19,16 @@ use std::path::PathBuf; use std::sync::Arc; use super::{ - get_imdb_table_schema, get_query_sql, IMDB_QUERY_END_ID, IMDB_QUERY_START_ID, - IMDB_TABLES, + IMDB_QUERY_END_ID, IMDB_QUERY_START_ID, IMDB_TABLES, get_imdb_table_schema, + get_query_sql, }; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -41,8 +41,8 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; +use clap::Args; use log::info; -use structopt::StructOpt; // hack to avoid `default_value is meaningless for bool` errors type BoolDefaultTrue = bool; @@ -57,41 +57,45 @@ type BoolDefaultTrue = 
bool; /// [2]: https://event.cwi.nl/da/job/imdb.tgz /// [3]: https://db.in.tum.de/~leis/qo/job.tgz -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number. If not specified, runs all queries - #[structopt(short, long)] + #[arg(short, long)] pub query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to data files - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// File format: `csv` or `parquet` - #[structopt(short = "f", long = "format", default_value = "csv")] + #[arg(short = 'f', long = "format", default_value = "csv")] file_format: String, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, /// Whether to disable collection of statistics (and cost based optimizations) or not. - #[structopt(short = "S", long = "disable-statistics")] + #[arg(short = 'S', long = "disable-statistics")] disable_statistics: bool, /// If true then hash join used, if false then sort merge join /// True by default. - #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + #[arg(short = 'j', long = "prefer_hash_join", default_value = "true")] prefer_hash_join: BoolDefaultTrue, + + /// How many bytes to buffer on the probe side of hash joins. + #[arg(long, default_value = "0")] + hash_join_buffering_capacity: usize, } fn map_query_id_to_str(query_id: usize) -> &'static str { @@ -306,8 +310,10 @@ impl RunOpt { .config()? 
.with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + config.options_mut().execution.hash_join_buffering_capacity = + self.hash_join_buffering_capacity; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); // register tables self.register_tables(&ctx).await?; @@ -517,6 +523,7 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), @@ -527,6 +534,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(map_query_id_to_str(query))?; @@ -553,6 +561,7 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), @@ -563,6 +572,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(map_query_id_to_str(query))?; diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index 07cffa5ae468e..7c598e65d824c 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -22,6 +22,10 @@ pub mod h2o; pub mod hj; pub mod imdb; pub mod nlj; +pub mod smj; +pub mod sort_pushdown; pub mod sort_tpch; +pub mod sql_benchmark; +pub mod tpcds; pub mod tpch; pub mod util; diff --git a/benchmarks/src/nlj.rs b/benchmarks/src/nlj.rs index 7d1e14f69439c..361cc35ec200c 100644 --- a/benchmarks/src/nlj.rs +++ b/benchmarks/src/nlj.rs @@ -16,11 +16,11 @@ // under the License. use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use clap::Args; use datafusion::physical_plan::execute_stream; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::instant::Instant; -use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; -use structopt::StructOpt; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; use futures::StreamExt; @@ -40,19 +40,19 @@ use futures::StreamExt; /// - Input size: Different combinations of left (build) side and right (probe) /// side sizes /// - Selectivity of join filters -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number (between 1 and 10). 
If not specified, runs all queries - #[structopt(short, long)] + #[arg(short, long)] query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// If present, write results json here - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, } @@ -207,8 +207,8 @@ impl RunOpt { }; let config = self.common.config()?; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { @@ -268,8 +268,8 @@ impl RunOpt { let elapsed = start.elapsed(); println!( - "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" - ); + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); query_results.push(QueryResult { elapsed, row_count }); } diff --git a/benchmarks/src/smj.rs b/benchmarks/src/smj.rs new file mode 100644 index 0000000000000..3d173b7116e2b --- /dev/null +++ b/benchmarks/src/smj.rs @@ -0,0 +1,647 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use clap::Args; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; + +use futures::StreamExt; + +/// Run the Sort Merge Join (SMJ) benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of SMJs. +/// +/// It uses equality join predicates (to ensure SMJ is selected) and varies: +/// - Join type: Inner/Left/Right/Full/LeftSemi/LeftAnti/RightSemi/RightAnti +/// - Key cardinality: 1:1, 1:N, N:M relationships +/// - Filter selectivity: Low (1%), Medium (10%), High (50%) +/// - Input sizes: Small to large, balanced and skewed +/// +/// All inputs are pre-sorted in CTEs before the join to isolate join +/// performance from sort overhead. +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 26). 
If not specified, runs all queries + #[arg(short, long)] + query: Option<usize>, + + /// Common options + #[command(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[arg(short = 'o', long = "output")] + output_path: Option<PathBuf>, +} + +/// Inline SQL queries for SMJ benchmarks +/// +/// Each query's comment includes: +/// - Join type +/// - Left row count × Right row count +/// - Key cardinality (rows per key) +/// - Filter selectivity (if applicable) +const SMJ_QUERIES: &[&str] = &[ + // Q1: INNER 1M x 1M | 1:1 + r#" + WITH t1_sorted AS ( + SELECT value as key FROM range(1000000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key FROM range(1000000) ORDER BY value + ) + SELECT t1_sorted.key as k1, t2_sorted.key as k2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q2: INNER 1M x 10M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q3: INNER 1M x 1M | 1:100 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q4: INNER 1M x 10M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data % 100 = 0 + "#, + // Q5: INNER 1M x 1M | 1:100 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t1_sorted.data <> t2_sorted.data AND t2_sorted.data % 10 = 0 + "#, + // Q6: LEFT 1M x 10M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 105000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q7: LEFT 1M x 10M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data IS NULL OR t2_sorted.data % 2 = 0 + "#, + // Q8: FULL 1M x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key,
data + ), + t2_sorted AS ( + SELECT value % 125000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q9: FULL 1M x 10M | 1:10 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE (t1_sorted.data IS NULL OR t2_sorted.data IS NULL + OR t1_sorted.data <> t2_sorted.data) + AND (t1_sorted.data IS NULL OR t1_sorted.data % 10 = 0) + "#, + // Q10: LEFT SEMI 1M x 10M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key + FROM range(10000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q11: LEFT SEMI 1M x 10M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 100 = 0 + ) + "#, + // Q12: LEFT SEMI 1M x 10M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 2 = 0 + ) + "#, + // Q13: LEFT SEMI 1M x 10M | 1:10 | 90% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 10 <> 0 + ) + "#, + // Q14: LEFT ANTI 1M x 10M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 105000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key + FROM range(10000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q15: LEFT ANTI 1M x 10M | 1:10 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 120000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key + FROM range(10000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = 
t1_sorted.key + ) + "#, + // Q16: LEFT ANTI 1M x 1M | 1:1 | stress + r#" + WITH t1_sorted AS ( + SELECT value % 110000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q17: INNER 1M x 50M | 1:50 | 5% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(50000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data <> t1_sorted.data AND t2_sorted.data % 20 = 0 + "#, + // Q18: LEFT SEMI 1M x 50M | 1:50 | 2% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(50000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 50 = 0 + ) + "#, + // Q19: LEFT ANTI 1M x 50M | 1:50 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 150000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key + FROM range(50000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q20: INNER 1M x 10M | 1:100 + GROUP BY + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, count(*) as cnt + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + GROUP BY t1_sorted.key + "#, + // Q21: INNER 10M x 10M | unique keys (1:1) | 50% join filter + r#" + WITH t1_sorted AS ( + SELECT value as key, value as data + FROM range(10000000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key, value as data + FROM range(10000000) ORDER BY value + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted + ON t1_sorted.key = t2_sorted.key + AND t1_sorted.data + t2_sorted.data < 10000000 + "#, + // Q22: LEFT 10M x 10M | unique keys (1:1) | 50% join filter + r#" + WITH t1_sorted AS ( + SELECT value as key, value as data + FROM range(10000000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key, value as data + FROM range(10000000) ORDER BY value + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted + ON t1_sorted.key = t2_sorted.key + AND t1_sorted.data + t2_sorted.data < 10000000 + "#, + // Q23: FULL 10M x 10M | unique keys (1:1) | 50% join filter + r#" + WITH t1_sorted AS ( + SELECT value as key, value as data + FROM range(10000000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key, value as data + FROM range(10000000) ORDER BY value + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted + ON 
t1_sorted.key = t2_sorted.key + AND t1_sorted.data + t2_sorted.data < 10000000 + "#, + // Q24: LEFT MARK 1M x 10M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE t1_sorted.data < 0 + OR EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 100 = 0 + ) + "#, + // Q25: LEFT MARK 1M x 10M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE t1_sorted.data < 0 + OR EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 2 = 0 + ) + "#, + // Q26: LEFT MARK 1M x 10M | 1:10 | 90% + r#" + WITH t1_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 100000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE t1_sorted.data < 0 + OR EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 10 <> 0 + ) + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running SMJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= SMJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. 
Available queries: 1 to {}", + SMJ_QUERIES.len() + ); + } + } + None => 1..=SMJ_QUERIES.len(), + }; + + let mut config = self.common.config()?; + // Disable hash joins to force SMJ + config = config.set_bool("datafusion.optimizer.prefer_hash_join", false); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + let query_index = query_id - 1; // Convert 1-based to 0-based index + + let sql = SMJ_QUERIES[query_index]; + benchmark_run.start_new_case(&format!("Query {query_id}")); + let expect_mark = query_id >= 24; + let query_run = self + .benchmark_query(sql, &query_id.to_string(), expect_mark, &ctx) + .await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("SMJ benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + expect_mark: bool, + ctx: &SessionContext, + ) -> Result<Vec<QueryResult>> { + let mut query_results = vec![]; + + // Validate that the query plan includes a Sort Merge Join + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("SortMergeJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Sort Merge Join. Physical plan: {plan_string}" + )); + } + + if expect_mark && !plan_string.contains("LeftMark") { + return Err(exec_datafusion_err!( + "Query {query_name} expected LeftMark join. Physical plan: {plan_string}" + )); + } + + for i in 0..self.common.iterations { + let start = Instant::now(); + + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each result batch after evaluation, to + /// minimize memory usage by not buffering results. + /// + /// Returns the total result row count + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result<usize> { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + + // Count the rows, then drop the batch immediately; nothing is + // buffered, which keeps memory pressure low + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/sort_pushdown.rs b/benchmarks/src/sort_pushdown.rs new file mode 100644 index 0000000000000..e7fce1921e7a8 --- /dev/null +++ b/benchmarks/src/sort_pushdown.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Benchmark for sort pushdown optimization.
+//!
+//! Tests performance of sort elimination when files are non-overlapping and
+//! internally sorted (declared via `--sorted` / `WITH ORDER`).
+//!
+//! Queries are loaded from external SQL files under `queries/sort_pushdown/`
+//! so they can also be run directly with `datafusion-cli`.
+//!
+//! # Usage
+//!
+//! ```text
+//! # Prepare sorted TPCH lineitem data (SF=1)
+//! ./bench.sh data sort_pushdown
+//!
+//! # Baseline (no WITH ORDER, full SortExec)
+//! ./bench.sh run sort_pushdown
+//!
+//! # With sort elimination (WITH ORDER, SortExec removed)
+//! ./bench.sh run sort_pushdown_sorted
+//! ```
+
+use clap::Args;
+use futures::StreamExt;
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use datafusion::datasource::TableProvider;
+use datafusion::datasource::file_format::parquet::ParquetFormat;
+use datafusion::datasource::listing::{
+ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
+};
+use datafusion::error::Result;
+use datafusion::execution::SessionStateBuilder;
+use datafusion::physical_plan::display::DisplayableExecutionPlan;
+use datafusion::physical_plan::{displayable, execute_stream};
+use datafusion::prelude::*;
+use datafusion_common::DEFAULT_PARQUET_EXTENSION;
+use datafusion_common::instant::Instant;
+
+use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats};
+
+/// Default path to query files, relative to the benchmark root
+const SORT_PUSHDOWN_QUERY_DIR: &str = "queries/sort_pushdown";
+
+#[derive(Debug, Args)]
+pub struct RunOpt {
+ /// Common options
+ #[command(flatten)]
+ common: CommonOpt,
+
+ /// Sort pushdown query number (1-4). If not specified, runs all queries
+ #[arg(short, long)]
+ pub query: Option<usize>,
+
+ /// Path to data files (lineitem). Only parquet format is supported.
+ #[arg(required = true, short = 'p', long = "path")]
+ path: PathBuf,
+
+ /// Path to JSON benchmark result to be compared using `compare.py`
+ #[arg(short = 'o', long = "output")]
+ output_path: Option<PathBuf>,
+
+ /// Path to directory containing query SQL files (q1.sql, q2.sql, ...).
+ /// Defaults to `queries/sort_pushdown/` relative to current directory.
+ #[arg(long = "queries-path")]
+ queries_path: Option<PathBuf>,
+
+ /// Mark the first column (l_orderkey) as sorted via WITH ORDER.
+ /// When set, enables sort elimination for matching queries.
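+ ///
+ /// With `--sorted`, the lineitem listing table is registered with a
+ /// `with_file_sort_order` on its first column (see `get_table` below),
+ /// which is what allows the optimizer to remove the `SortExec`.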
+ #[arg(short = 't', long = "sorted")]
+ sorted: bool,
+}
+
+impl RunOpt {
+ const TABLES: [&'static str; 1] = ["lineitem"];
+
+ fn queries_dir(&self) -> PathBuf {
+ self.queries_path
+ .clone()
+ .unwrap_or_else(|| PathBuf::from(SORT_PUSHDOWN_QUERY_DIR))
+ }
+
+ fn load_query(&self, query_id: usize) -> Result<String> {
+ let path = self.queries_dir().join(format!("q{query_id}.sql"));
+ std::fs::read_to_string(&path).map_err(|e| {
+ datafusion_common::DataFusionError::Execution(format!(
+ "Failed to read query file {}: {e}",
+ path.display()
+ ))
+ })
+ }
+
+ fn available_queries(&self) -> Vec<usize> {
+ let dir = self.queries_dir();
+ let mut ids = Vec::new();
+ if let Ok(entries) = std::fs::read_dir(&dir) {
+ for entry in entries.flatten() {
+ let name = entry.file_name();
+ let name = name.to_string_lossy();
+ if let Some(rest) = name.strip_prefix('q')
+ && let Some(num_str) = rest.strip_suffix(".sql")
+ && let Ok(id) = num_str.parse::<usize>()
+ {
+ ids.push(id);
+ }
+ }
+ }
+ ids.sort();
+ ids
+ }
+
+ pub async fn run(&self) -> Result<()> {
+ let mut benchmark_run = BenchmarkRun::new();
+
+ let query_ids = match self.query {
+ Some(query_id) => vec![query_id],
+ None => self.available_queries(),
+ };
+
+ for query_id in query_ids {
+ benchmark_run.start_new_case(&format!("{query_id}"));
+
+ let query_results = self.benchmark_query(query_id).await;
+ match query_results {
+ Ok(query_results) => {
+ for iter in query_results {
+ benchmark_run.write_iter(iter.elapsed, iter.row_count);
+ }
+ }
+ Err(e) => {
+ benchmark_run.mark_failed();
+ eprintln!("Query {query_id} failed: {e}");
+ }
+ }
+ }
+
+ benchmark_run.maybe_write_json(self.output_path.as_ref())?;
+ benchmark_run.maybe_print_failures();
+ Ok(())
+ }
+
+ async fn benchmark_query(&self, query_id: usize) -> Result<Vec<QueryResult>> {
+ let sql = self.load_query(query_id)?;
+
+ let config = self.common.config()?;
+ let rt = self.common.build_runtime()?;
+ let state = SessionStateBuilder::new()
+ .with_config(config)
+ .with_runtime_env(rt)
+ .with_default_features()
+ .build();
+ let ctx = SessionContext::from(state);
+
+ self.register_tables(&ctx).await?;
+
+ let mut millis = vec![];
+ let mut query_results = vec![];
+ for i in 0..self.iterations() {
+ let start = Instant::now();
+
+ let row_count = self.execute_query(&ctx, sql.as_str()).await?;
+
+ let elapsed = start.elapsed();
+ let ms = elapsed.as_secs_f64() * 1000.0;
+ millis.push(ms);
+
+ println!(
+ "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows"
+ );
+ query_results.push(QueryResult { elapsed, row_count });
+ }
+
+ let avg = millis.iter().sum::<f64>() / millis.len() as f64;
+ println!("Query {query_id} avg time: {avg:.2} ms");
+
+ print_memory_stats();
+
+ Ok(query_results)
+ }
+
+ async fn register_tables(&self, ctx: &SessionContext) -> Result<()> {
+ for table in Self::TABLES {
+ let table_provider = self.get_table(ctx, table).await?;
+ ctx.register_table(table, table_provider)?;
+ }
+ Ok(())
+ }
+
+ async fn execute_query(&self, ctx: &SessionContext, sql: &str) -> Result<usize> {
+ let debug = self.common.debug;
+ let plan = ctx.sql(sql).await?;
+ let (state, plan) = plan.into_parts();
+
+ if debug {
+ println!("=== Logical plan ===\n{plan}\n");
+ }
+
+ let plan = state.optimize(&plan)?;
+ if debug {
+ println!("=== Optimized logical plan ===\n{plan}\n");
+ }
+ let physical_plan = state.create_physical_plan(&plan).await?;
+ if debug {
+ println!(
+ "=== Physical plan ===\n{}\n",
+ displayable(physical_plan.as_ref()).indent(true)
+ );
+ }
+
+ let mut row_count = 0;
+ let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?;
+ while let Some(batch) = stream.next().await {
+ row_count += batch?.num_rows();
+ }
+
+ if debug {
+ println!(
+ "=== Physical plan with metrics ===\n{}\n",
+ DisplayableExecutionPlan::with_metrics(physical_plan.as_ref())
+ .indent(true)
+ );
+ }
+
+ Ok(row_count)
+ }
+
+ async fn get_table(
+ &self,
+ ctx: &SessionContext,
+ table: &str,
+ ) -> Result<Arc<dyn TableProvider>> {
+ let path = self.path.to_str().unwrap();
+ let state = ctx.state();
+ let path = format!("{path}/{table}");
+ let format = Arc::new(
+ ParquetFormat::default()
+ .with_options(ctx.state().table_options().parquet.clone()),
+ );
+ let extension = DEFAULT_PARQUET_EXTENSION;
+
+ let options = ListingOptions::new(format)
+ .with_file_extension(extension)
+ .with_collect_stat(true); // Always collect statistics for sort pushdown
+
+ let table_path = ListingTableUrl::parse(path)?;
+ let schema = options.infer_schema(&state, &table_path).await?;
+ let options = if self.sorted {
+ let key_column_name = schema.fields()[0].name();
+ options
+ .with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]])
+ } else {
+ options
+ };
+
+ let config = ListingTableConfig::new(table_path)
+ .with_listing_options(options)
+ .with_schema(schema);
+
+ Ok(Arc::new(ListingTable::try_new(config)?))
+ }
+
+ fn iterations(&self) -> usize {
+ self.common.iterations
+ }
+}
diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs
index 09b5a676bbff1..95c90d826de20 100644
--- a/benchmarks/src/sort_tpch.rs
+++ b/benchmarks/src/sort_tpch.rs
@@ -21,10 +21,10 @@
 //! Another `Sort` benchmark focus on single core execution. This benchmark
 //! runs end-to-end sort queries and test the performance on multiple CPU cores.
 
+use clap::Args;
 use futures::StreamExt;
 use std::path::PathBuf;
 use std::sync::Arc;
-use structopt::StructOpt;
 
 use datafusion::datasource::file_format::parquet::ParquetFormat;
 use datafusion::datasource::listing::{
@@ -36,41 +36,41 @@ use datafusion::execution::SessionStateBuilder;
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::{displayable, execute_stream};
 use datafusion::prelude::*;
+use datafusion_common::DEFAULT_PARQUET_EXTENSION;
 use datafusion_common::instant::Instant;
 use datafusion_common::utils::get_available_parallelism;
-use datafusion_common::DEFAULT_PARQUET_EXTENSION;
 
-use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult};
+use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats};
 
-#[derive(Debug, StructOpt)]
+#[derive(Debug, Args)]
 pub struct RunOpt {
     /// Common options
-    #[structopt(flatten)]
+    #[command(flatten)]
     common: CommonOpt,
 
     /// Sort query number. If not specified, runs all queries
-    #[structopt(short, long)]
+    #[arg(short, long)]
     pub query: Option<usize>,
 
     /// Path to data files (lineitem). Only parquet format is supported
-    #[structopt(parse(from_os_str), required = true, short = "p", long = "path")]
+    #[arg(required = true, short = 'p', long = "path")]
     path: PathBuf,
 
     /// Path to JSON benchmark result to be compare using `compare.py`
-    #[structopt(parse(from_os_str), short = "o", long = "output")]
+    #[arg(short = 'o', long = "output")]
     output_path: Option<PathBuf>,
 
     /// Load the data into a MemTable before executing the query
-    #[structopt(short = "m", long = "mem-table")]
+    #[arg(short = 'm', long = "mem-table")]
     mem_table: bool,
 
     /// Mark the first column of each table as sorted in ascending order.
     /// The tables should have been created with the `--sort` option for this to have any effect.
-    #[structopt(short = "t", long = "sorted")]
+    #[arg(short = 't', long = "sorted")]
     sorted: bool,
 
     /// Append a `LIMIT n` clause to the query
-    #[structopt(short = "l", long = "limit")]
+    #[arg(short = 'l', long = "limit")]
     limit: Option<usize>,
 }
 
@@ -209,10 +209,10 @@ impl RunOpt {
     /// Benchmark query `query_id` in `SORT_QUERIES`
     async fn benchmark_query(&self, query_id: usize) -> Result<Vec<QueryResult>> {
         let config = self.common.config()?;
-        let rt_builder = self.common.runtime_env_builder()?;
+        let rt = self.common.build_runtime()?;
         let state = SessionStateBuilder::new()
             .with_config(config)
-            .with_runtime_env(rt_builder.build_arc()?)
+            .with_runtime_env(rt)
             .with_default_features()
             .build();
         let ctx = SessionContext::from(state);
diff --git a/benchmarks/src/sql_benchmark.rs b/benchmarks/src/sql_benchmark.rs
new file mode 100644
index 0000000000000..34614b132483f
--- /dev/null
+++ b/benchmarks/src/sql_benchmark.rs
@@ -0,0 +1,3538 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, RecordBatch};
+use arrow::datatypes::*;
+use arrow::error::ArrowError;
+use arrow::util::display::{ArrayFormatter, FormatOptions};
+use datafusion::dataframe::DataFrameWriteOptions;
+use datafusion::datasource::MemTable;
+use datafusion::physical_plan::execute_stream;
+use datafusion::prelude::{CsvReadOptions, DataFrame, SessionContext};
+use datafusion_common::config::CsvOptions;
+use datafusion_common::{DataFusionError, Result, exec_datafusion_err};
+use futures::StreamExt;
+use log::{debug, info, trace, warn};
+use regex::Regex;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::fs::{self, File, OpenOptions};
+use std::io::{BufRead, BufReader};
+use std::path::{Path, PathBuf};
+use std::sync::{Arc, LazyLock};
+
+/// A collection of benchmark configurations and state used by the DataFusion
+/// sql test harness. Each benchmark is defined by a file that can contain
+/// directives such as `load`, `run`, `assert`, `result`, etc. The
+/// `SqlBenchmark` struct holds the parsed data from that file and
+/// the impl provides methods to run, assert, persist, verify, and clean up
+/// benchmark results.
+#[derive(Debug, Clone)]
+pub struct SqlBenchmark {
+ /// Human-readable name of the benchmark.
+ name: String,
+ /// Top-level group name (derived from the file path or defined in a benchmark).
+ group: String,
+ /// Subgroup name, often a logical grouping.
+ subgroup: String,
+ /// Full path to the benchmark file.
+ benchmark_path: PathBuf,
+ /// Mapping of placeholder keys to concrete values (e.g. `"BENCHMARK_DIR"`).
+ replacement_mapping: HashMap<String, String>,
+ /// Expected strings that must appear in the physical plan of the queries.
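+ /// Populated by the `expect_plan` directive (e.g. `expect_plan SortMergeJoinExec`);
+ /// each entry must appear as a substring of the formatted physical plan,
+ /// as checked in `validate_expected_plan`.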
+ expect: Vec<String>,
+ /// All SQL queries grouped by directive (`load`, `run`, etc.).
+ queries: HashMap<QueryDirective, Vec<String>>,
+ /// Queries whose results are persisted to disk for later comparison.
+ result_queries: Vec<BenchmarkQuery>,
+ /// Queries whose results are asserted against an expected table.
+ assert_queries: Vec<BenchmarkQuery>,
+ /// Flag indicating whether the benchmark has been fully loaded
+ is_loaded: bool,
+ /// Stores the last run results if needed so they can be compared or persisted.
+ last_results: Option<Vec<RecordBatch>>,
+ /// echo statements
+ echo: Vec<String>,
+}
+
+impl SqlBenchmark {
+ pub async fn new(
+ ctx: &SessionContext,
+ full_path: impl AsRef<Path>,
+ benchmark_directory: impl AsRef<Path>,
+ ) -> Result<Self> {
+ let full_path = full_path.as_ref();
+ let benchmark_directory = benchmark_directory.as_ref();
+ let group_name = parse_group_from_path(full_path, benchmark_directory);
+ let mut bm = Self {
+ name: String::new(),
+ group: group_name,
+ subgroup: String::new(),
+ benchmark_path: full_path.to_path_buf(),
+ replacement_mapping: HashMap::new(),
+ expect: vec![],
+ queries: HashMap::new(),
+ result_queries: vec![],
+ assert_queries: vec![],
+ is_loaded: false,
+ last_results: None,
+ echo: vec![],
+ };
+ insert_replacement(
+ &mut bm.replacement_mapping,
+ "BENCHMARK_DIR",
+ benchmark_directory.to_string_lossy().into_owned(),
+ );
+
+ let path = bm.benchmark_path.clone();
+ bm.process_file(ctx, &path).await?;
+
+ Ok(bm)
+ }
+
+ /// Initializes the benchmark by executing `load` and `init` queries.
+ ///
+ /// Registers any required tables or sets up state in the provided
+ /// `SessionContext` before running queries. This method is idempotent:
+ /// calling it multiple times on the same instance returns
+ /// immediately after the first successful initialization.
+ ///
+ /// # Errors
+ /// Returns an error if any `load` or `init` query fails, or if the
+ /// benchmark file does not contain a `run` query.
+ pub async fn initialize(&mut self, ctx: &SessionContext) -> Result<()> {
+ if self.is_loaded {
+ return Ok(());
+ }
+
+ let path = self.benchmark_path.to_string_lossy().into_owned();
+
+ // validate there was a run query
+ if !self.queries.contains_key(&QueryDirective::Run) {
+ return Err(exec_datafusion_err!(
+ "Invalid benchmark file: no \"run\" query specified: {path}"
+ ));
+ }
+
+ // display any echoes
+ self.echo.iter().for_each(|txt| println!("{txt}"));
+
+ let load_queries = self.queries.get(&QueryDirective::Load);
+
+ if let Some(queries) = load_queries {
+ for query in queries {
+ debug!("Executing load query {query}");
+ ctx.sql(query).await?.collect().await?;
+ }
+ }
+
+ let init_queries = self.queries.get(&QueryDirective::Init);
+
+ if let Some(queries) = init_queries {
+ for query in queries {
+ debug!("Executing init query {query}");
+ ctx.sql(query).await?.collect().await?;
+ }
+ }
+
+ self.is_loaded = true;
+
+ Ok(())
+ }
+
+ /// Executes the `assert` queries and compares actual results against
+ /// expected values.
+ ///
+ /// Each `assert` query must be followed by a result table (separated by
+ /// `----`) in the benchmark file. The assertion passes only if the
+ /// returned record batches exactly match the expected rows.
+ ///
+ /// # Errors
+ /// Returns an error if any `assert` query fails, or if the actual and
+ /// expected results differ in row count or cell values.
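+ ///
+ /// Example benchmark snippet (illustrative; `II` declares two columns):
+ ///
+ /// ```text
+ /// assert II
+ /// SELECT 1, 'one'
+ /// ----
+ /// 1|one
+ /// ```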
+ pub async fn assert(&mut self, ctx: &SessionContext) -> Result<()> {
+ info!("Running assertions...");
+
+ for assert_query in &self.assert_queries {
+ let query = &assert_query.query;
+
+ info!("Executing assert query {query}");
+
+ let result = ctx.sql(query).await?.collect().await?;
+ let formatted_actual_results = format_record_batches(&result)?;
+
+ Self::compare_results(
+ assert_query,
+ &formatted_actual_results,
+ &assert_query.expected_result,
+ )?;
+ }
+
+ Ok(())
+ }
+
+ /// Executes the `run` queries, optionally saving results for later
+ /// verification. If there are multiple queries only the results for
+ /// the last query are saved.
+ ///
+ /// When `save_results` is `true`, it collects `SELECT`/`WITH` query
+ /// results and stores them in `last_results`.
+ ///
+ /// When `save_results` is `false`, it streams results and counts rows
+ /// without buffering them.
+ ///
+ /// If an 'expect' string is defined this method also validates that
+ /// the physical plan contains that string.
+ ///
+ /// # Errors
+ /// Returns an error if a `run` query fails or if expected plan strings
+ /// are not found.
+ pub async fn run(&mut self, ctx: &SessionContext, save_results: bool) -> Result<()> {
+ let run_queries = self
+ .queries
+ .get(&QueryDirective::Run)
+ .ok_or_else(|| exec_datafusion_err!("Run query should be loaded by now"))?;
+
+ let mut result_count = 0;
+
+ let result: Vec<RecordBatch> = {
+ let mut local_result = vec![];
+
+ for query in run_queries {
+ match save_results {
+ true => {
+ debug!(
+ "Running query (saving results) {}-{}: {query}",
+ self.group, self.subgroup
+ );
+
+ let df = ctx.sql(query).await?;
+ if !self.expect.is_empty() {
+ let physical_plan = df.create_physical_plan().await?;
+ self.validate_expected_plan(&physical_plan)?;
+ }
+
+ let result_schema = Arc::new(df.schema().as_arrow().clone());
+ let mut batches = df.collect().await?;
+ let trimmed = query.trim_start();
+
+ // save the output for select/with queries
+ if starts_with_ignore_ascii_case(trimmed, "select")
+ || starts_with_ignore_ascii_case(trimmed, "with")
+ {
+ if batches.is_empty() {
+ batches.push(RecordBatch::new_empty(result_schema));
+ }
+ let row_count_for_query =
+ batches.iter().map(RecordBatch::num_rows).sum::<usize>();
+ debug!(
+ "Persisting {} batches ({} rows)...",
+ batches.len(),
+ row_count_for_query
+ );
+
+ result_count = row_count_for_query;
+ local_result = batches;
+ }
+ }
+ false => {
+ debug!(
+ "Running query (ignoring results) {}-{}: {query}",
+ self.group, self.subgroup
+ );
+
+ result_count = self
+ .execute_sql_without_result_buffering(query, ctx)
+ .await?;
+ }
+ }
+ }
+
+ Ok::<Vec<RecordBatch>, DataFusionError>(local_result)
+ }?;
+
+ debug!("Results have {result_count} rows");
+
+ // Store results for verification
+ self.last_results = Some(result);
+
+ Ok(())
+ }
+
+ /// Calls run and persists results to disk as a CSV file.
+ ///
+ /// Requires that the benchmark defines a `result` or `result_query`.
+ /// Registers the results in a memory table and writes them to disk with
+ /// pipe delimiters and a header row.
+ ///
+ /// # Errors
+ /// Returns an error if no results are available or if writing to the
+ /// target path fails.
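+ ///
+ /// The persisted file can later be consumed by `verify` (via
+ /// `read_query_from_file`), which reads it back with the same `|`
+ /// delimiter and header settings.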
+ pub async fn persist(&mut self, ctx: &SessionContext) -> Result<()> { + self.run(ctx, true).await?; + + // Check if we have result queries to persist for + if self.result_queries.is_empty() { + info!("No result paths to persist"); + return Ok(()); + } + + let results = self + .last_results + .as_ref() + .expect("run should store last_results after successful execution"); + + let query = &self.result_queries[0]; + let path = query.path.as_ref().ok_or_else(|| { + exec_datafusion_err!( + "Unable to persist results from query '{}', no result specified", + query.query + ) + })?; + + info!("Persisting results for query to {path}"); + + let first_batch = results + .first() + .ok_or_else(|| exec_datafusion_err!("Results should be loaded"))?; + + let schema = first_batch.schema(); + let provider = MemTable::try_new(schema, vec![results.clone()])?; + + ctx.register_table("persist_data", Arc::new(provider))?; + + let df = ctx.table("persist_data").await?; + df.write_csv( + path, + DataFrameWriteOptions::new(), + Some( + CsvOptions::default() + .with_delimiter(b'|') + .with_has_header(true), + ), + ) + .await?; + + ctx.deregister_table("persist_data")?; + + Ok(()) + } + + /// Verifies persisted results against expected values. + /// + /// Executes the `result_query` or uses the stored last run results, then + /// compares actual output rows to the expected values defined in the + /// benchmark file. + /// + /// # Errors + /// Returns an error if no results are available or if the actual and + /// expected results differ in count or content. + pub async fn verify(&mut self, ctx: &SessionContext) -> Result<()> { + // Check if we have result queries to verify + if self.result_queries.is_empty() { + return Ok(()); + } + + if self.last_results.is_none() { + return Err(exec_datafusion_err!( + "No results available for verification. Run the benchmark first." + )); + } + + info!("Verifying results..."); + + self.load_expected_result_files(ctx).await?; + + // Get the first result query (assuming only one for now) + let query = &self.result_queries[0]; + let formatted_actual_results = if !query.query.trim().is_empty() { + let results = ctx.sql(&query.query).await?.collect().await?; + format_record_batches(&results) + } else { + let actual_results = self + .last_results + .as_ref() + .expect("last_results should be present after successful run"); + format_record_batches(actual_results) + }?; + + Self::compare_results(query, &formatted_actual_results, &query.expected_result) + } + + /// Runs `cleanup` queries to reset state after the benchmark run. 
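+ ///
+ /// A benchmark that `load`s a table would typically pair it with a
+ /// matching `cleanup` block, e.g. `DROP TABLE t;`.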
+ pub async fn cleanup(&mut self, ctx: &SessionContext) -> Result<()> {
+ info!("Running cleanup...");
+
+ let cleanup_queries = self.queries.get(&QueryDirective::Cleanup);
+
+ if let Some(queries) = cleanup_queries {
+ for query in queries {
+ let _ = ctx.sql(query).await?.collect().await?;
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn load_expected_result_files(&mut self, ctx: &SessionContext) -> Result<()> {
+ for query in &mut self.result_queries {
+ if query.query.trim().is_empty() {
+ let Some(path) = query.path.clone() else {
+ continue;
+ };
+
+ let loaded_query =
+ read_query_from_file(ctx, path, &HashMap::new()).await?;
+ query.column_count = loaded_query.column_count;
+ query.expected_result = loaded_query.expected_result;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn compare_results(
+ query: &BenchmarkQuery,
+ actual_results: &[Vec<String>],
+ expected_results: &[Vec<String>],
+ ) -> Result<()> {
+ if actual_results.is_empty() && expected_results.is_empty() {
+ return Ok(());
+ }
+
+ // Compare row count
+ if actual_results.len() != expected_results.len() {
+ return Err(exec_datafusion_err!(
+ "Error in result: expected {} rows but got {} for query {}",
+ expected_results.len(),
+ actual_results.len(),
+ query.query
+ ));
+ }
+
+ // Compare values
+ let zipped = actual_results
+ .iter()
+ .enumerate()
+ .zip(expected_results.iter());
+
+ for ((row_idx, actual), expected) in zipped {
+ trace!(
+ "row {}\nactual: {actual:?}\nexpected: {expected:?}",
+ row_idx + 1
+ );
+
+ // Compare column count
+ if actual.len() != expected.len() {
+ return Err(exec_datafusion_err!(
+ "Error in result: expected {} columns but got {} for query {}",
+ expected.len(),
+ actual.len(),
+ query.query
+ ));
+ }
+
+ for (col_idx, expected_val) in
+ expected.iter().enumerate().take(query.column_count)
+ {
+ // The row-width check above guarantees this index exists.
+ let actual_val = &actual[col_idx];
+
+ trace!("actual_val = {actual_val:?}\nexpected_val = {expected_val:?}");
+
+ if (expected_val == "NULL" && actual_val.is_empty())
+ || (expected_val == actual_val)
+ || (expected_val == "(empty)"
+ && (actual_val.is_empty() || actual_val == "NULL"))
+ {
+ continue;
+ }
+
+ return Err(exec_datafusion_err!(
+ "Error in result on row {}, column {} running query \"{}\": expected value \
+ \"{expected_val}\" but got value \"{actual_val}\" in row: {actual:?}",
+ row_idx + 1,
+ col_idx + 1,
+ query.query
+ ));
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn process_file(&mut self, ctx: &SessionContext, path: &Path) -> Result<()> {
+ debug!("Processing file {}", path.display());
+
+ let mut replacement_mapping = self.replacement_mapping.clone();
+ insert_replacement(
+ &mut replacement_mapping,
+ "FILE_PATH",
+ path.to_string_lossy().into_owned(),
+ );
+
+ let mut reader = BenchmarkFileReader::new(path, replacement_mapping)?;
+ let mut line = String::with_capacity(1024);
+ let mut reader_result = reader.read_line(&mut line);
+
+ while let Some(result) = reader_result {
+ match result {
+ Ok(_) => {
+ if !is_blank_or_comment_line(&line) {
+ // boxing required because of recursion
+ Box::pin(self.process_line(ctx, &mut reader, &mut line)).await?;
+ }
+ }
+ Err(e) => return Err(e),
+ }
+
+ // Clear the line buffer for the next iteration.
+ line.clear();
+ reader_result = reader.read_line(&mut line);
+ }
+
+ Ok(())
+ }
+
+ async fn process_line(
+ &mut self,
+ ctx: &SessionContext,
+ reader: &mut BenchmarkFileReader,
+ line: &mut String,
+ ) -> Result<()> {
+ // Split the line into directive and arguments.
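+ // e.g. `run q1.sql` yields the directive `run` plus a file argument,
+ // while a bare `run` is followed by inline SQL on subsequent lines.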
+ let cloned_line = line.trim_start().to_string();
+ let splits: Vec<&str> = cloned_line.split_whitespace().collect();
+
+ BenchmarkDirective::select(reader, splits[0])?
+ .process(ctx, self, reader, line, &splits)
+ .await
+ }
+
+ fn process_query(&mut self, splits: &[&str], mut query: String) -> Result<()> {
+ debug!("Processing query {query}");
+
+ // Trim and validate.
+ query = query.trim().to_string();
+ if query.is_empty() {
+ return Ok(());
+ }
+
+ // remove comments
+ query = query
+ .lines()
+ .filter(|line| !is_comment_line(line))
+ .collect::<Vec<_>>()
+ .join("\n");
+
+ if query.trim().is_empty() {
+ return Ok(());
+ }
+
+ query = process_replacements(&query, self.replacement_mapping())?;
+
+ let directive = QueryDirective::parse(splits[0]).ok_or_else(|| {
+ exec_datafusion_err!("Invalid query directive: {}", splits[0])
+ })?;
+
+ self.queries.entry(directive).or_default().push(query);
+
+ Ok(())
+ }
+
+ fn validate_expected_plan(&self, physical_plan: &impl Debug) -> Result<()> {
+ if self.expect.is_empty() {
+ return Ok(());
+ }
+
+ let plan_string = format!("{physical_plan:#?}");
+
+ for exp_str in &self.expect {
+ if !plan_string.contains(exp_str) {
+ return Err(exec_datafusion_err!(
+ "The query physical plan does not contain the expected string '{exp_str}'. Physical plan: {plan_string}"
+ ));
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn execute_sql_without_result_buffering(
+ &self,
+ sql: &str,
+ ctx: &SessionContext,
+ ) -> Result<usize> {
+ let mut row_count = 0;
+
+ let df = ctx.sql(sql).await?;
+ let physical_plan = df.create_physical_plan().await?;
+
+ self.validate_expected_plan(&physical_plan)?;
+ let mut stream = execute_stream(physical_plan, ctx.task_ctx())?;
+
+ while let Some(batch) = stream.next().await {
+ row_count += batch?.num_rows();
+
+ // Evaluate the result and do nothing, the result will be dropped
+ // to reduce memory pressure
+ }
+
+ Ok(row_count)
+ }
+
+ pub fn name(&self) -> &str {
+ &self.name
+ }
+
+ pub fn group(&self) -> &str {
+ &self.group
+ }
+
+ pub fn subgroup(&self) -> &str {
+ &self.subgroup
+ }
+
+ pub fn benchmark_path(&self) -> &Path {
+ &self.benchmark_path
+ }
+
+ pub fn replacement_mapping(&self) -> &HashMap<String, String> {
+ &self.replacement_mapping
+ }
+
+ pub fn queries(&self) -> &HashMap<QueryDirective, Vec<String>> {
+ &self.queries
+ }
+
+ pub fn result_queries(&self) -> &[BenchmarkQuery] {
+ &self.result_queries
+ }
+
+ pub fn assert_queries(&self) -> &[BenchmarkQuery] {
+ &self.assert_queries
+ }
+
+ pub fn is_loaded(&self) -> bool {
+ self.is_loaded
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum QueryDirective {
+ Load,
+ Run,
+ Init,
+ Cleanup,
+}
+
+impl QueryDirective {
+ fn parse(value: &str) -> Option<Self> {
+ if value.eq_ignore_ascii_case("load") {
+ Some(Self::Load)
+ } else if value.eq_ignore_ascii_case("init") {
+ Some(Self::Init)
+ } else if value.eq_ignore_ascii_case("run") {
+ Some(Self::Run)
+ } else if value.eq_ignore_ascii_case("cleanup") {
+ Some(Self::Cleanup)
+ } else {
+ None
+ }
+ }
+
+ fn as_str(self) -> &'static str {
+ match self {
+ Self::Load => "load",
+ Self::Run => "run",
+ Self::Init => "init",
+ Self::Cleanup => "cleanup",
+ }
+ }
+}
+
+enum BenchmarkDirective {
+ Load,
+ Run,
+ Init,
+ Cleanup,
+ Name,
+ Group,
+ Subgroup,
+ Expect,
+ Assert,
+ ResultQuery,
+ Results,
+ Template,
+ Include,
+ Echo,
+}
+
+impl BenchmarkDirective {
+ fn select(
+ reader: &BenchmarkFileReader,
+ directive: &str,
+ ) -> Result<Self> {
+ if directive.eq_ignore_ascii_case("load") {
+ Ok(BenchmarkDirective::Load)
+ } else if directive.eq_ignore_ascii_case("run")
{ + Ok(BenchmarkDirective::Run) + } else if directive.eq_ignore_ascii_case("init") { + Ok(BenchmarkDirective::Init) + } else if directive.eq_ignore_ascii_case("cleanup") { + Ok(BenchmarkDirective::Cleanup) + } else if directive.eq_ignore_ascii_case("name") { + Ok(BenchmarkDirective::Name) + } else if directive.eq_ignore_ascii_case("group") { + Ok(BenchmarkDirective::Group) + } else if directive.eq_ignore_ascii_case("subgroup") { + Ok(BenchmarkDirective::Subgroup) + } else if directive.eq_ignore_ascii_case("expect_plan") { + Ok(BenchmarkDirective::Expect) + } else if directive.eq_ignore_ascii_case("assert") { + Ok(BenchmarkDirective::Assert) + } else if directive.eq_ignore_ascii_case("result_query") { + Ok(BenchmarkDirective::ResultQuery) + } else if directive.eq_ignore_ascii_case("result") { + Ok(BenchmarkDirective::Results) + } else if directive.eq_ignore_ascii_case("template") { + Ok(BenchmarkDirective::Template) + } else if directive.eq_ignore_ascii_case("include") { + Ok(BenchmarkDirective::Include) + } else if directive.eq_ignore_ascii_case("echo") { + Ok(BenchmarkDirective::Echo) + } else { + Err(exec_datafusion_err!( + "{}", + reader.format_exception(&format!("Unrecognized command: {directive}")) + )) + } + } + + async fn process( + &self, + ctx: &SessionContext, + bench: &mut SqlBenchmark, + reader: &mut BenchmarkFileReader, + line: &mut String, + splits: &[&str], + ) -> Result<()> { + trace!("-- handling {}", splits[0]); + + match self { + BenchmarkDirective::Load + | BenchmarkDirective::Run + | BenchmarkDirective::Init + | BenchmarkDirective::Cleanup => { + Self::process_query_directive(bench, reader, line, splits) + } + BenchmarkDirective::Name => Self::process_metadata_value( + bench, + reader, + line, + "name", + "BENCH_NAME", + "name must be followed by a value", + ), + BenchmarkDirective::Group => Self::process_metadata_value( + bench, + reader, + line, + "group", + "BENCH_GROUP", + "group must be followed by a value", + ), + BenchmarkDirective::Subgroup => Self::process_metadata_value( + bench, + reader, + line, + "subgroup", + "BENCH_SUBGROUP", + "subgroup must be followed by a value", + ), + BenchmarkDirective::Expect => Self::process_expect(bench, reader, splits), + BenchmarkDirective::Assert => { + Self::process_assert(bench, reader, line, splits) + } + BenchmarkDirective::ResultQuery => { + Self::process_result_query(bench, reader, line, splits) + } + BenchmarkDirective::Results => Self::process_results(bench, reader, splits), + BenchmarkDirective::Template => { + Self::process_template(ctx, bench, reader, line, splits).await + } + BenchmarkDirective::Include => { + Self::process_include(ctx, bench, reader, splits).await + } + BenchmarkDirective::Echo => Self::process_echo(bench, reader, splits), + } + } + + fn process_query_directive( + bench: &mut SqlBenchmark, + reader: &mut BenchmarkFileReader, + line: &mut String, + splits: &[&str], + ) -> Result<()> { + let directive = QueryDirective::parse(splits[0]).ok_or_else(|| { + exec_datafusion_err!("Invalid query directive: {}", splits[0]) + })?; + + if directive == QueryDirective::Run && bench.queries.contains_key(&directive) { + return Err(exec_datafusion_err!( + "Multiple calls to run in the same benchmark file" + )); + } + + line.clear(); + + // Read the query body until a blank line or EOF. 
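+ // Comment lines are skipped; every other line is appended verbatim to
+ // `query`, one source line per iteration.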
+ let mut query = String::new(); + let mut reader_result = reader.read_line(line); + + loop { + match reader_result { + Some(Ok(_)) => { + if is_comment_line(line) { + // comment, ignore + } else if is_blank_line(line) { + break; + } else { + query.push_str(line); + query.push('\n'); + } + } + Some(Err(e)) => return Err(e), + None => break, + } + + // Clear the line buffer for the next iteration. + line.clear(); + reader_result = reader.read_line(line); + } + + // Optional file parameter. + if splits.len() > 1 && !splits[1].is_empty() { + if !query.trim().is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception(&format!( + "{} directive must use either a query file or inline SQL, not both", + directive.as_str() + )) + )); + } + + debug!("Processing {} file: {}", splits[0], splits[1]); + + let query_file = fs::read_to_string(splits[1]).map_err(|e| { + exec_datafusion_err!("Failed to read query file {}: {e}", splits[1]) + })?; + let query_file = query_file.replace("\r\n", "\n"); + + // some files have multiple queries, split apart + for query in split_query_statements(&query_file) { + bench.process_query(splits, query.to_string())?; + } + } else if directive == QueryDirective::Run { + for query in split_query_statements(&query) { + bench.process_query(splits, query.to_string())?; + } + } else { + bench.process_query(splits, query)?; + } + + Ok(()) + } + + fn process_metadata_value( + bench: &mut SqlBenchmark, + reader: &mut BenchmarkFileReader, + line: &str, + directive: &str, + replacement_key: &str, + message: &str, + ) -> Result<()> { + let value = + directive_value(reader, line.trim_start(), directive, message)?.to_string(); + + match directive { + "name" => bench.name.clone_from(&value), + "group" => bench.group.clone_from(&value), + "subgroup" => bench.subgroup.clone_from(&value), + _ => unreachable!("unsupported metadata directive: {directive}"), + } + + insert_replacement( + &mut bench.replacement_mapping, + replacement_key, + value.clone(), + ); + insert_replacement(&mut reader.replacements, replacement_key, value); + + Ok(()) + } + + fn process_expect( + bench: &mut SqlBenchmark, + reader: &BenchmarkFileReader, + splits: &[&str], + ) -> Result<()> { + trace!("-- handling {}", splits[0]); + + if splits.len() <= 1 || splits[1].is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "expect_plan must be followed by a string to search in the physical plan" + ) + )); + } + + bench.expect.push(splits[1..].join(" ").to_string()); + + Ok(()) + } + + fn process_assert( + bench: &mut SqlBenchmark, + reader: &mut BenchmarkFileReader, + line: &mut String, + splits: &[&str], + ) -> Result<()> { + // count the amount of columns based on character count. The actual + // character used is irrelevant. + if splits.len() <= 1 || splits[1].is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "assert must be followed by a column count (e.g. assert III)" + ) + )); + } + + line.clear(); + + // read the actual query + let mut found_break = false; + let mut sql = String::new(); + let mut reader_result = reader.read_line(line); + + loop { + match reader_result { + Some(Ok(_)) => { + if line.trim() == "----" { + found_break = true; + break; + } + sql.push('\n'); + sql.push_str(line); + } + Some(Err(e)) => return Err(e), + None => break, + } + + // Clear the line buffer for the next iteration. 
+ line.clear(); + reader_result = reader.read_line(line); + } + + if !found_break { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "assert must be followed by a query and a result (separated by ----)" + ) + )); + } + + bench + .assert_queries + .push(read_query_from_reader(reader, &sql, splits[1])?); + + Ok(()) + } + + fn process_results( + bench: &mut SqlBenchmark, + reader: &BenchmarkFileReader, + splits: &[&str], + ) -> Result<()> { + if splits.len() <= 1 || splits[1].is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "result must be followed by a path to a result file" + ) + )); + } + + if !bench.result_queries.is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception("multiple results found") + )); + } + + let path = process_replacements(splits[1], &bench.replacement_mapping)?; + + bench.result_queries.push(BenchmarkQuery { + path: Some(path), + query: String::new(), + column_count: 0, + expected_result: vec![], + }); + + Ok(()) + } + + fn process_result_query( + bench: &mut SqlBenchmark, + reader: &mut BenchmarkFileReader, + line: &mut String, + splits: &[&str], + ) -> Result<()> { + if splits.len() <= 1 || splits[1].is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "result_query must be followed by a column count (e.g. result_query III)" + ) + )); + } + + line.clear(); + + let mut sql = String::new(); + let mut found_break = false; + let mut reader_result = reader.read_line(line); + + loop { + match reader_result { + Some(Ok(_)) => { + if line.trim() == "----" { + found_break = true; + break; + } + sql.push_str(line); + sql.push('\n'); + } + Some(Err(e)) => return Err(e), + None => break, + } + + // Clear the line buffer for the next iteration. + line.clear(); + reader_result = reader.read_line(line); + } + + if !found_break { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "result_query must be followed by a query and a result (separated by ----)" + ) + )); + } + + let result_check = read_query_from_reader(reader, &sql, splits[1])?; + + if !bench.result_queries.is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception("multiple results found") + )); + } + bench.result_queries.push(result_check); + + Ok(()) + } + + async fn process_template( + ctx: &SessionContext, + bench: &mut SqlBenchmark, + reader: &mut BenchmarkFileReader, + line: &mut String, + splits: &[&str], + ) -> Result<()> { + if splits.len() != 2 || splits[1].is_empty() { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception("template requires a single template path") + )); + } + + // template: update the path to read + bench.benchmark_path = PathBuf::from(splits[1]); + + line.clear(); + + // now read parameters + let mut reader_result = reader.read_line(line); + + loop { + match reader_result { + Some(Ok(_)) => { + if is_comment_line(line) { + // Clear the line buffer for the next iteration. + line.clear(); + reader_result = reader.read_line(line); + continue; + } + if is_blank_line(line) { + break; + } + + let Some((key, value)) = line.trim_start().split_once('=') else { + return Err(exec_datafusion_err!( + "{}", + reader.format_exception( + "Expected a template parameter in the form of X=Y" + ) + )); + }; + insert_replacement( + &mut bench.replacement_mapping, + key.trim(), + value.trim().to_string(), + ); + } + Some(Err(e)) => return Err(e), + None => break, + } + + // Clear the line buffer for the next iteration. 
+ line.clear();
+ reader_result = reader.read_line(line);
+ }
+
+ // restart the load from the template file
+ Box::pin(bench.process_file(ctx, Path::new(splits[1]))).await
+ }
+
+ async fn process_include(
+ ctx: &SessionContext,
+ bench: &mut SqlBenchmark,
+ reader: &BenchmarkFileReader,
+ splits: &[&str],
+ ) -> Result<()> {
+ if splits.len() != 2 || splits[1].is_empty() {
+ return Err(exec_datafusion_err!(
+ "{}",
+ reader.format_exception("include requires a single argument")
+ ));
+ }
+
+ Box::pin(bench.process_file(ctx, Path::new(splits[1]))).await
+ }
+
+ fn process_echo(
+ bench: &mut SqlBenchmark,
+ reader: &BenchmarkFileReader,
+ splits: &[&str],
+ ) -> Result<()> {
+ if splits.len() < 2 {
+ return Err(exec_datafusion_err!(
+ "{}",
+ reader.format_exception("Echo requires an argument")
+ ));
+ }
+
+ bench.echo.push(splits[1..].join(" "));
+
+ Ok(())
+ }
+}
+
+struct BenchmarkFileReader {
+ path: PathBuf,
+ reader: BufReader<File>,
+ line_nr: usize,
+ replacements: HashMap<String, String>,
+}
+
+impl BenchmarkFileReader {
+ fn new<P: Into<PathBuf>>(
+ path: P,
+ replacements: HashMap<String, String>,
+ ) -> Result<Self> {
+ let path = path.into();
+ let file = OpenOptions::new().read(true).open(&path)?;
+
+ Ok(Self {
+ path,
+ reader: BufReader::new(file),
+ line_nr: 0,
+ replacements,
+ })
+ }
+
+ /// Read the next line, applying replacements and removing line terminators.
+ fn read_line(&mut self, line: &mut String) -> Option<Result<()>> {
+ match self.reader.read_line(line) {
+ Ok(0) => None,
+ Ok(_) => {
+ self.line_nr += 1;
+
+ // Trim newline and carriage return without changing other content.
+ let trimmed_len = line.trim_end_matches(['\n', '\r']).len();
+ line.truncate(trimmed_len);
+
+ match process_replacements(line, &self.replacements) {
+ Ok(l) => {
+ *line = l;
+ Some(Ok(()))
+ }
+ Err(error) => Some(Err(error)),
+ }
+ }
+ Err(e) => Some(Err(e.into())),
+ }
+ }
+
+ fn format_exception(&self, msg: &str) -> String {
+ format!("{}:{} - {}", self.path.display(), self.line_nr, msg)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct BenchmarkQuery {
+ path: Option<String>,
+ query: String,
+ column_count: usize,
+ expected_result: Vec<Vec<String>>,
+}
+
+// ---- utility functions below
+
+fn directive_value<'a>(
+ reader: &BenchmarkFileReader,
+ line: &'a str,
+ directive: &str,
+ message: &str,
+) -> Result<&'a str> {
+ let value = line
+ .get(..directive.len())
+ .filter(|prefix| prefix.eq_ignore_ascii_case(directive))
+ .and_then(|_| line.get(directive.len()..))
+ .map(str::trim)
+ .filter(|s| !s.is_empty())
+ .ok_or_else(|| exec_datafusion_err!("{}", reader.format_exception(message)))?;
+
+ Ok(value)
+}
+
+fn parse_group_from_path(path: &Path, benchmark_directory: &Path) -> String {
+ let mut group_name = String::new();
+ let mut parent = path.parent();
+
+ while let Some(p) = parent {
+ if path_ends_with_ignore_ascii_case(p, benchmark_directory) {
+ break;
+ }
+
+ if let Some(dir_name) = p.file_name() {
+ group_name = dir_name.to_string_lossy().into_owned();
+ }
+
+ parent = p.parent();
+ }
+
+ if group_name.is_empty() {
+ warn!("Unable to find group name in path: {}", path.display());
+ }
+
+ group_name
+}
+
+fn path_ends_with_ignore_ascii_case(path: &Path, suffix: &Path) -> bool {
+ let mut path_components = path.components().rev();
+
+ for suffix_component in suffix.components().rev() {
+ let Some(path_component) = path_components.next() else {
+ return false;
+ };
+
+ if !path_component
+ .as_os_str()
+ .to_string_lossy()
+ .eq_ignore_ascii_case(&suffix_component.as_os_str().to_string_lossy())
+ {
+ return false;
+ }
+ }
+
+ true
+}
+
+fn starts_with_ignore_ascii_case(input: &str, prefix: &str) -> bool {
+ input
+ .get(..prefix.len())
+ .is_some_and(|value| value.eq_ignore_ascii_case(prefix))
+}
+
+fn split_query_statements(sql: &str) -> impl Iterator<Item = &str> {
+ sql.split("\n\n")
+ .flat_map(|query| {
+ query
+ .split_inclusive(";\n")
+ .map(|part| part.trim_end_matches('\n'))
+ })
+ .filter(|query| !query.trim().is_empty())
+}
+
+fn is_blank_line(line: &str) -> bool {
+ line.trim().is_empty()
+}
+
+fn is_comment_line(line: &str) -> bool {
+ let line = line.trim_start();
+ line.starts_with('#') || line.starts_with("--")
+}
+
+fn is_blank_or_comment_line(line: &str) -> bool {
+ is_blank_line(line) || is_comment_line(line)
+}
+
+fn insert_replacement(
+ replacement_map: &mut HashMap<String, String>,
+ key: &str,
+ value: String,
+) {
+ replacement_map.insert(key.to_lowercase(), value);
+}
+
+fn replace_all(
+ re: &Regex,
+ haystack: &str,
+ replacement: impl Fn(&regex::Captures) -> Result<String>,
+) -> Result<String> {
+ let mut new = String::with_capacity(haystack.len());
+ let mut last_match = 0;
+
+ for caps in re.captures_iter(haystack) {
+ let m = caps.get(0).unwrap();
+
+ new.push_str(&haystack[last_match..m.start()]);
+ new.push_str(&replacement(&caps)?);
+
+ last_match = m.end();
+ }
+
+ new.push_str(&haystack[last_match..]);
+
+ Ok(new)
+}
+
+static TRUE_FALSE_REPLACEMENT_RE: LazyLock<Regex> = LazyLock::new(|| {
+ Regex::new(r"\$\{(\w+)(?::-([^|}]+))?\|([^|]+)\|([^}]+)}")
+ .expect("Regex failed to compile")
+});
+
+static VARIABLE_REPLACEMENT_RE: LazyLock<Regex> = LazyLock::new(|| {
+ Regex::new(r"\$\{(\w+)(?::-([^}]+))?}").expect("Regex failed to compile")
+});
+
+/// Replace all `${KEY}` or `${KEY:-default}` placeholders in a string according to the mapping.
+/// Also handles `${KEY:-default|true value|false value}` syntax.
+fn process_replacements(
+ input: &str,
+ replacement_map: &HashMap<String, String>,
+) -> Result<String> {
+ process_replacements_with_env(input, replacement_map, |key| std::env::var(key).ok())
+}
+
+fn process_replacements_with_env(
+ input: &str,
+ replacement_map: &HashMap<String, String>,
+ get_env: impl Fn(&str) -> Option<String>,
+) -> Result<String> {
+ debug!("processing replacements for line '{input}'");
+
+ // handle ${VAR:-default|true value|false value} syntax
+ let replacement = |caps: &regex::Captures| -> Result<String> {
+ let key = &caps[1];
+ let default = caps.get(2).map(|m| m.as_str().to_string());
+ let true_val = &caps[3];
+ let false_val = &caps[4];
+
+ let value = lookup_replacement_value(key, replacement_map, &get_env).or(default);
+
+ match value {
+ Some(v) if v.eq_ignore_ascii_case("true") => Ok(true_val.to_string()),
+ Some(_) => Ok(false_val.to_string()),
+ None => Err(exec_datafusion_err!("Missing value for key '{key}'")),
+ }
+ };
+ let input = replace_all(&TRUE_FALSE_REPLACEMENT_RE, input, replacement)?;
+
+ // handle ${KEY} and ${KEY:-default}
+ let replacement = |caps: &regex::Captures| -> Result<String> {
+ let key = &caps[1];
+ let default = caps.get(2);
+
+ if let Some(v) = lookup_replacement_value(key, replacement_map, &get_env) {
+ return Ok(v.to_string());
+ }
+
+ // use default if it was set
+ if let Some(def) = default {
+ Ok(def.as_str().to_string())
+ } else {
+ Err(exec_datafusion_err!("Missing value for key '{key}'"))
+ }
+ };
+
+ replace_all(&VARIABLE_REPLACEMENT_RE, &input, replacement)
+}
+
+fn lookup_replacement_value(
+ key: &str,
+ replacement_map: &HashMap<String, String>,
+ get_env: &impl Fn(&str) -> Option<String>,
+) -> Option<String> {
+ if let Some(v) = replacement_map.get(&key.to_lowercase()) {
+ return Some(v.to_string());
+ }
+
+ // look in env variables
+ get_env(&key.to_uppercase())
+}
+
+fn read_query_from_reader(
+ reader: &mut BenchmarkFileReader,
+ sql: &str,
+ header: &str,
+) -> Result<BenchmarkQuery> {
+ let column_count = header.len();
+ let mut expected_result = vec![];
+ let mut line = String::new();
+ let mut reader_result = reader.read_line(&mut line);
+
+ loop {
+ match reader_result {
+ Some(Ok(_)) => {
+ if is_comment_line(&line) {
+ // comment, ignore
+ } else if is_blank_line(&line) {
+ break;
+ } else {
+ let result_splits: Vec<&str> = line.split(['\t', '|']).collect();
+
+ if result_splits.len() != column_count {
+ return Err(exec_datafusion_err!(
+ "{} {line}",
+ reader.format_exception(&format!(
+ "expected {} values but got {}",
+ column_count,
+ result_splits.len(),
+ ))
+ ));
+ }
+
+ expected_result
+ .push(result_splits.into_iter().map(|s| s.to_string()).collect());
+ }
+ }
+ Some(Err(e)) => return Err(e),
+ None => break,
+ }
+
+ // Clear the line buffer for the next iteration.
+ line.clear();
+ reader_result = reader.read_line(&mut line);
+ }
+
+ Ok(BenchmarkQuery {
+ path: None,
+ query: sql.to_string(),
+ column_count,
+ expected_result,
+ })
+}
+
+async fn read_query_from_file(
+ ctx: &SessionContext,
+ path: impl AsRef<Path>,
+ replacement_mapping: &HashMap<String, String>,
+) -> Result<BenchmarkQuery> {
+ // Process replacements in file path
+ let path = path.as_ref().to_string_lossy();
+ let path = process_replacements(&path, replacement_mapping)?;
+ let df: DataFrame = ctx
+ .read_csv(
+ path.clone(),
+ CsvReadOptions::new()
+ .has_header(true)
+ .delimiter(b'|')
+ .null_regex(Some("NULL".to_string()))
+ // we only want string values, we do not want to infer the schema
+ .schema_infer_max_records(0),
+ )
+ .await?;
+
+ // Get schema to determine column count
+ let schema = df.schema();
+ let column_count = schema.fields().len();
+
+ if column_count == 0 {
+ return Err(exec_datafusion_err!(
+ "Result file {path} did not contain any columns"
+ ));
+ }
+
+ // Execute and collect results
+ let batches = df.collect().await?;
+ // Convert record batches to string vectors
+ let expected_result = format_record_batches(&batches)?;
+
+ Ok(BenchmarkQuery {
+ path: Some(path),
+ query: String::new(),
+ column_count,
+ expected_result,
+ })
+}
+
+fn format_record_batches(
+ batches: &[RecordBatch],
+) -> Result<Vec<Vec<String>>, DataFusionError> {
+ let mut expected_result = vec![];
+ let arrow_format_options = FormatOptions::default()
+ .with_null("NULL")
+ .with_display_error(true);
+
+ for batch in batches {
+ let schema = batch.schema_ref();
+
+ let formatters = batch
+ .columns()
+ .iter()
+ .zip(schema.fields().iter())
+ .map(|(c, field)| make_array_formatter(c, &arrow_format_options, Some(field)))
+ .collect::<Result<Vec<_>, ArrowError>>()?;
+
+ for row in 0..batch.num_rows() {
+ let mut cells = vec![];
+ for formatter in &formatters {
+ cells.push(formatter.value(row).to_string());
+ }
+ expected_result.push(cells);
+ }
+ }
+
+ Ok(expected_result)
+}
+
+fn make_array_formatter<'a>(
+ array: &'a dyn Array,
+ options: &FormatOptions<'a>,
+ field: Option<&'a Field>,
+) -> Result<ArrayFormatter<'a>, ArrowError> {
+ match options.formatter_factory() {
+ None => ArrayFormatter::try_new(array, options),
+ Some(formatters) => formatters
+ .create_array_formatter(array, options, field)
+ .transpose()
+ .unwrap_or_else(|| ArrayFormatter::try_new(array, options)),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use datafusion::prelude::SessionContext;
+ use std::fs;
+ use std::path::{Path, PathBuf};
+ use tempfile::{TempDir, tempdir};
+
+ fn write_test_file(temp_dir: &TempDir, name: &str, contents: &str) -> PathBuf {
+ let path = temp_dir.path().join(name);
+ fs::write(&path, contents).expect("failed to write benchmark test file");
+ path
+ }
+
+ async fn parse_benchmark_file(path: &Path) -> Result<SqlBenchmark> {
+ let ctx = SessionContext::new();
+ let path_string = path.to_string_lossy().into_owned();
+ SqlBenchmark::new(&ctx, &path_string, "/tmp").await
+ }
+
+ async fn parse_benchmark(contents: &str) -> Result<SqlBenchmark> {
+ let temp_dir = tempdir().expect("failed to create benchmark test directory");
+ let path = write_test_file(&temp_dir, "parser.benchmark", contents);
+
+ parse_benchmark_file(&path).await
+ }
+
+ async fn assert_parse_error(contents: &str, expected_message: &str) {
+ let error = parse_benchmark(contents)
+ .await
+ .expect_err("benchmark parsing should fail");
+
+ let message = error.to_string();
+ assert!(
+ message.contains(expected_message),
+ "expected error containing {expected_message:?}, got {message:?}"
+ );
+ }
+
+ fn assert_result_error_contains(result: Result<SqlBenchmark>, expected_message: &str) {
+ let error = result.expect_err("operation should fail");
+ let message = error.to_string();
+ assert!(
+ message.contains(expected_message),
+ "expected error containing {expected_message:?}, got {message:?}"
+ );
+ }
+
+ fn formatted_last_results(benchmark: &SqlBenchmark) -> Vec<Vec<String>> {
+ format_record_batches(
+ benchmark
+ .last_results
+ .as_ref()
+ .expect("last results should be set"),
+ )
+ .expect("results should format")
+ }
+
+ fn read_all_files_in_dir(path: &Path) -> String {
+ let mut entries = fs::read_dir(path)
+ .expect("directory should be readable")
+ .filter_map(Result::ok)
+ .map(|entry| entry.path())
+ .filter(|path| path.is_file())
+ .collect::<Vec<_>>();
+ entries.sort();
+
+ let mut contents = String::new();
+ for path in entries {
+ contents
+ .push_str(&fs::read_to_string(path).expect("file should be readable"));
+ }
+ contents
+ }
+
+ fn replacement_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
+ let mut replacements = HashMap::new();
+ for (key, value) in entries {
+ insert_replacement(&mut replacements, key, value.to_string());
+ }
+ replacements
+ }
+
+ fn env_map(entries: &[(&str, &str)]) -> HashMap<String, String> {
+ entries
+ .iter()
+ .map(|(key, value)| (key.to_string(), value.to_string()))
+ .collect()
+ }
+
+ // Replacement tests cover benchmark variable expansion syntax.
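+ //
+ // Supported placeholder forms (resolved from the replacement map first,
+ // then from environment variables):
+ // ${KEY} required value
+ // ${KEY:-default} value with a fallback
+ // ${KEY|true val|false val} boolean branch on the value
+ // ${KEY:-default|true val|false val} boolean branch with a fallback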
+ + #[test] + fn process_replacements_replaces_map_values_case_insensitively() { + let replacements = replacement_map(&[ + ("BENCH_NAME", "tpch"), + ("QUERY_NUMBER_PADDED", "01"), + ("format_1", "parquet"), + ]); + + let actual = process_replacements_with_env( + "${bench_name}/q${query_number_padded}.${FORMAT_1}", + &replacements, + |_| None, + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "tpch/q01.parquet"); + } + + #[test] + fn process_replacements_uses_env_when_map_value_is_missing() { + let replacements = HashMap::new(); + let env = env_map(&[("DATA_DIR", "/tmp/data")]); + + let actual = process_replacements_with_env( + "${data_dir}/lineitem.parquet", + &replacements, + |key| env.get(key).cloned(), + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "/tmp/data/lineitem.parquet"); + } + + #[test] + fn process_replacements_prefers_map_over_env() { + let replacements = replacement_map(&[("BENCH_SIZE", "10")]); + let env = env_map(&[("BENCH_SIZE", "100")]); + + let actual = + process_replacements_with_env("sf${BENCH_SIZE}", &replacements, |key| { + env.get(key).cloned() + }) + .expect("replacement should succeed"); + + assert_eq!(actual, "sf10"); + } + + #[test] + fn process_replacements_uses_default_for_missing_variable() { + let replacements = HashMap::new(); + + let actual = process_replacements_with_env( + "load_${BENCH_SUBGROUP:-groupby}_${FILE_TYPE:-csv}.sql", + &replacements, + |_| None, + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_groupby_csv.sql"); + } + + #[test] + fn process_replacements_reports_missing_variable_without_default() { + let replacements = HashMap::new(); + + let error = process_replacements_with_env("${MISSING}", &replacements, |_| None) + .expect_err("replacement should fail"); + + assert!( + error + .to_string() + .contains("Missing value for key 'MISSING'"), + "unexpected error: {error}" + ); + } + + #[test] + fn process_replacements_applies_true_false_true_branch() { + let replacements = HashMap::new(); + let env = env_map(&[("USE_PARQUET", "TrUe")]); + + let actual = process_replacements_with_env( + "load_${USE_PARQUET:-false|parquet|csv}.sql", + &replacements, + |key| env.get(key).cloned(), + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_parquet.sql"); + } + + #[test] + fn process_replacements_applies_true_false_false_branch() { + let replacements = HashMap::new(); + let env = env_map(&[("USE_PARQUET", "false")]); + + let actual = process_replacements_with_env( + "load_${USE_PARQUET:-true|parquet|csv}.sql", + &replacements, + |key| env.get(key).cloned(), + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_csv.sql"); + } + + #[test] + fn process_replacements_uses_map_for_true_false_branch() { + let replacements = replacement_map(&[("USE_PARQUET", "true")]); + + let actual = process_replacements_with_env( + "load_${USE_PARQUET:-false|parquet|csv}.sql", + &replacements, + |_| None, + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_parquet.sql"); + } + + #[test] + fn process_replacements_prefers_map_over_env_for_true_false_branch() { + let replacements = replacement_map(&[("USE_PARQUET", "false")]); + let env = env_map(&[("USE_PARQUET", "true")]); + + let actual = process_replacements_with_env( + "load_${USE_PARQUET:-true|parquet|csv}.sql", + &replacements, + |key| env.get(key).cloned(), + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_csv.sql"); + } + + #[test] + fn 
process_replacements_uses_true_false_default_for_missing_true_value() { + let replacements = HashMap::new(); + + let actual = process_replacements_with_env( + "load_${USE_PARQUET:-true|parquet|csv}.sql", + &replacements, + |_| None, + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_parquet.sql"); + } + + #[test] + fn process_replacements_uses_true_false_default_for_missing_false_value() { + let replacements = HashMap::new(); + + let actual = process_replacements_with_env( + "load_${USE_PARQUET:-false|parquet|csv}.sql", + &replacements, + |_| None, + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "load_csv.sql"); + } + + #[test] + fn process_replacements_reports_missing_true_false_variable_without_default() { + let replacements = HashMap::new(); + + let error = process_replacements_with_env( + "load_${USE_PARQUET|parquet|csv}.sql", + &replacements, + |_| None, + ) + .expect_err("replacement should fail"); + + assert!( + error + .to_string() + .contains("Missing value for key 'USE_PARQUET'"), + "unexpected error: {error}" + ); + } + + #[test] + fn process_replacements_resolves_variables_after_true_false_replacement() { + let replacements = replacement_map(&[("FILE_TYPE", "parquet")]); + let env = env_map(&[("USE_TYPED_PATH", "true")]); + + let actual = process_replacements_with_env( + "${USE_TYPED_PATH:-false|data.${FILE_TYPE}|data.csv}", + &replacements, + |key| env.get(key).cloned(), + ) + .expect("replacement should succeed"); + + assert_eq!(actual, "data.parquet"); + } + + #[test] + fn process_replacements_leaves_unsupported_placeholder_syntax_unchanged() { + let replacements = HashMap::new(); + + let actual = + process_replacements_with_env("${BAD-KEY:-fallback}", &replacements, |_| { + None + }) + .expect("unsupported placeholder should not match replacement regex"); + + assert_eq!(actual, "${BAD-KEY:-fallback}"); + } + + // Parser tests cover benchmark directives and parse-time validation. 
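+ //
+ // Directives exercised below: name, group, subgroup, expect_plan, echo,
+ // load, init, run, cleanup, assert, result_query, result, include, and
+ // template (see `BenchmarkDirective::select`).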
+ + #[tokio::test] + async fn parser_accepts_metadata_expect_echo_and_sql_sections() { + let benchmark = parse_benchmark( + r#" +# top-level comments are ignored +name Parser Success +group Parser Group +subgroup Parser Subgroup +expect_plan ProjectionExec with details +echo hello from parser + +load +-- query comments are ignored +CREATE TABLE t AS VALUES (1); + +init +CREATE VIEW v AS SELECT * FROM t; + +run +SELECT * FROM v; + +cleanup +DROP VIEW v; +"#, + ) + .await + .expect("benchmark should parse"); + + assert_eq!(benchmark.name(), "Parser Success"); + assert_eq!(benchmark.group(), "Parser Group"); + assert_eq!(benchmark.subgroup(), "Parser Subgroup"); + assert_eq!(benchmark.expect, vec!["ProjectionExec with details"]); + assert_eq!(benchmark.echo, vec!["hello from parser"]); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Load) + .expect("load query"), + &vec!["CREATE TABLE t AS VALUES (1);".to_string()] + ); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Init) + .expect("init query"), + &vec!["CREATE VIEW v AS SELECT * FROM t;".to_string()] + ); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["SELECT * FROM v;".to_string()] + ); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Cleanup) + .expect("cleanup query"), + &vec!["DROP VIEW v;".to_string()] + ); + } + + #[tokio::test] + async fn parser_splits_inline_run_block_on_semicolon_newline() { + let benchmark = parse_benchmark( + r#" +run +CREATE TABLE t AS SELECT 1 AS value; +SELECT value + 1 AS value FROM t; +DROP TABLE t; +"#, + ) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec![ + "CREATE TABLE t AS SELECT 1 AS value;".to_string(), + "SELECT value + 1 AS value FROM t;".to_string(), + "DROP TABLE t;".to_string(), + ] + ); + } + + #[tokio::test] + async fn parser_accepts_assert_with_expected_rows() { + let benchmark = parse_benchmark( + r#" +assert II +select 1, 'one' +---- +1|one +2 two +"#, + ) + .await + .expect("benchmark should parse"); + + let query = benchmark + .assert_queries() + .first() + .expect("assert query should be parsed"); + + assert_eq!(query.column_count, 2); + assert!(query.query.contains("select 1, 'one'")); + assert_eq!( + query.expected_result, + vec![ + vec!["1".to_string(), "one".to_string()], + vec!["2".to_string(), "two".to_string()] + ] + ); + } + + #[tokio::test] + async fn parser_accepts_result_query_with_expected_rows() { + let benchmark = parse_benchmark( + r#" +result_query II +select 1, 'one' +---- +1|one +NULL|(empty) +"#, + ) + .await + .expect("benchmark should parse"); + + let query = benchmark + .result_queries() + .first() + .expect("result query should be parsed"); + + assert_eq!(query.path, None); + assert_eq!(query.column_count, 2); + assert!(query.query.contains("select 1, 'one'")); + assert_eq!( + query.expected_result, + vec![ + vec!["1".to_string(), "one".to_string()], + vec!["NULL".to_string(), "(empty)".to_string()] + ] + ); + } + + #[tokio::test] + async fn parser_records_result_file_without_parsing_contents() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let result_path = + write_test_file(&temp_dir, "result.csv", "col_a|col_b\n1|one\nNULL|two\n"); + let benchmark_path = write_test_file( + &temp_dir, + "result.benchmark", + &format!("result {}\n", result_path.display()), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + 
.expect("benchmark should parse"); + + let query = benchmark + .result_queries() + .first() + .expect("result file should be parsed"); + + assert_eq!(query.path, Some(result_path.to_string_lossy().into_owned())); + assert_eq!(query.column_count, 0); + assert!(query.expected_result.is_empty()); + } + + #[tokio::test] + async fn parser_accepts_include_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let include_path = + write_test_file(&temp_dir, "include.benchmark", "run\nselect 1\n"); + + let benchmark_path = write_test_file( + &temp_dir, + "include_driver.benchmark", + &format!("include {}\n", include_path.display()), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + let benchmark = result.expect("benchmark should parse"); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["select 1".to_string()] + ); + } + + #[tokio::test] + async fn parser_accepts_template_file_with_parameters() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let template_path = write_test_file( + &temp_dir, + "template_success.benchmark", + "# template comments are ignored\nrun\n-- query comments are ignored\nselect '${TABLE_NAME}', '${BENCHMARK_DIR}'\n", + ); + + let benchmark_path = write_test_file( + &temp_dir, + "template_success_driver.benchmark", + &format!( + "template {}\n# parameter comments are ignored\nTABLE_NAME=orders\n", + template_path.display() + ), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + let benchmark = result.expect("benchmark should parse"); + assert_eq!(benchmark.benchmark_path(), template_path.as_path()); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["select 'orders', '/tmp'".to_string()] + ); + } + + #[tokio::test] + async fn parser_trims_template_parameter_keys_and_values() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let template_path = write_test_file( + &temp_dir, + "template_trim.benchmark", + "run\nselect '${TABLE_NAME}'\n", + ); + + let benchmark_path = write_test_file( + &temp_dir, + "template_trim_driver.benchmark", + &format!( + "template {}\n TABLE_NAME = orders \n", + template_path.display() + ), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["select 'orders'".to_string()] + ); + assert_eq!( + benchmark.replacement_mapping().get("table_name"), + Some(&"orders".to_string()) + ); + } + + #[tokio::test] + async fn parser_preserves_expected_result_cell_whitespace() { + let benchmark = parse_benchmark("assert I\nselect ' x '\n----\n x \n") + .await + .expect("benchmark should parse"); + + let query = benchmark + .assert_queries() + .first() + .expect("assert query should be parsed"); + + assert_eq!(query.expected_result, vec![vec![" x ".to_string()]]); + } + + #[tokio::test] + async fn parser_accepts_indented_comments_and_blank_lines() { + let benchmark = + parse_benchmark(" # comment\n -- comment\n run\n select 1\n \n") + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["select 1".to_string()] + ); + } + + #[tokio::test] + async fn parser_accepts_case_insensitive_query_directives() { + let benchmark = parse_benchmark("RUN\nselect 1\n") + .await + .expect("benchmark 
should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["select 1".to_string()] + ); + } + + #[tokio::test] + async fn parser_accepts_query_file_and_splits_statements() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file( + &temp_dir, + "queries.sql", + "-- leading comment\nSELECT 1 AS value;\nSELECT 2 AS value;\n\n# another comment\nWITH t AS (SELECT 3 AS value) SELECT * FROM t;\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "query_file.benchmark", + &format!("run {}\n", query_path.display()), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run queries"), + &vec![ + "SELECT 1 AS value;".to_string(), + "SELECT 2 AS value;".to_string(), + "WITH t AS (SELECT 3 AS value) SELECT * FROM t;".to_string(), + ] + ); + } + + #[tokio::test] + async fn parser_accepts_replacements_in_query_file_path() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = + write_test_file(&temp_dir, "queries.sql", "SELECT 5 AS value;\n"); + let template_path = write_test_file( + &temp_dir, + "query_file_path_template.benchmark", + "run ${QUERY_PATH}\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "query_file_path_driver.benchmark", + &format!( + "template {}\nQUERY_PATH={}\n", + template_path.display(), + query_path.display() + ), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["SELECT 5 AS value;".to_string()] + ); + } + + #[tokio::test] + async fn parser_rejects_inline_sql_when_query_file_is_provided() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = + write_test_file(&temp_dir, "queries.sql", "SELECT 1 AS value;\n"); + let benchmark_path = write_test_file( + &temp_dir, + "query_file_with_inline_body.benchmark", + &format!("run {}\nSELECT 999 AS value;\n", query_path.display()), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains( + result, + "run directive must use either a query file or inline SQL, not both", + ); + } + + #[tokio::test] + async fn parser_rejects_inline_sql_when_load_file_is_provided() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file( + &temp_dir, + "load.sql", + "CREATE TABLE t AS SELECT 1 AS value;\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "load_file_with_inline_body.benchmark", + &format!( + "load {}\nCREATE TABLE u AS SELECT 2 AS value;\n", + query_path.display() + ), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains( + result, + "load directive must use either a query file or inline SQL, not both", + ); + } + + #[tokio::test] + async fn parser_rejects_inline_sql_when_init_file_is_provided() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file( + &temp_dir, + "init.sql", + "CREATE VIEW v AS SELECT 1 AS value;\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "init_file_with_inline_body.benchmark", + &format!( + "init {}\nCREATE VIEW w AS SELECT 2 AS value;\n", + 
query_path.display() + ), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains( + result, + "init directive must use either a query file or inline SQL, not both", + ); + } + + #[tokio::test] + async fn parser_rejects_inline_sql_when_cleanup_file_is_provided() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file(&temp_dir, "cleanup.sql", "DROP TABLE t;\n"); + let benchmark_path = write_test_file( + &temp_dir, + "cleanup_file_with_inline_body.benchmark", + &format!("cleanup {}\nDROP TABLE u;\n", query_path.display()), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains( + result, + "cleanup directive must use either a query file or inline SQL, not both", + ); + } + + #[tokio::test] + async fn parser_ignores_query_file_with_only_comments_and_blank_lines() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file( + &temp_dir, + "queries.sql", + "# comment\n\n-- another comment\n\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "empty_query_file.benchmark", + &format!("run {}\n", query_path.display()), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + assert!(!benchmark.queries().contains_key(&QueryDirective::Run)); + } + + #[tokio::test] + async fn parser_splits_query_file_with_windows_line_endings() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file( + &temp_dir, + "queries.sql", + "SELECT 1 AS value;\r\nSELECT 2 AS value;\r\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "windows_query_file.benchmark", + &format!("run {}\n", query_path.display()), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run queries"), + &vec![ + "SELECT 1 AS value;".to_string(), + "SELECT 2 AS value;".to_string() + ] + ); + } + + #[tokio::test] + async fn parser_rejects_unknown_command() { + assert_parse_error("wat\n", "Unrecognized command: wat").await; + } + + #[tokio::test] + async fn parser_rejects_assert_without_column_count() { + assert_parse_error( + "assert\nselect 1\n----\n1\n", + "assert must be followed by a column count", + ) + .await; + } + + #[tokio::test] + async fn parser_rejects_assert_without_result_separator() { + assert_parse_error( + "assert I\nselect 1\n1\n", + "assert must be followed by a query and a result (separated by ----)", + ) + .await; + } + + #[tokio::test] + async fn parser_rejects_result_query_without_separator() { + assert_parse_error( + "result_query I\nselect 1\n1\n", + "result_query must be followed by a query and a result (separated by ----)", + ) + .await; + } + + #[tokio::test] + async fn parser_rejects_result_query_with_wrong_column_count() { + assert_parse_error( + "result_query II\nselect 1\n----\n1\n", + "expected 2 values but got 1", + ) + .await; + } + + #[tokio::test] + async fn parser_rejects_multiple_result_queries() { + assert_parse_error( + "result_query I\nselect 1\n----\n1\n\nresult_query I\nselect 2\n----\n2\n", + "multiple results found", + ) + .await; + } + + #[tokio::test] + async fn parser_rejects_duplicate_run_directives() { + assert_parse_error("run\nselect 1\n\nrun\nselect 2\n", "Multiple calls to run") + .await; + } + + #[tokio::test] + async 
fn parser_accepts_multiple_load_directives() { + let benchmark = parse_benchmark( + "load\nCREATE TABLE t AS SELECT 1;\n\nload\nCREATE TABLE u AS SELECT 2;\n", + ) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Load) + .expect("load queries"), + &vec![ + "CREATE TABLE t AS SELECT 1;".to_string(), + "CREATE TABLE u AS SELECT 2;".to_string(), + ] + ); + } + + #[tokio::test] + async fn parser_accepts_multiple_init_directives() { + let benchmark = parse_benchmark( + "init\nCREATE VIEW v AS SELECT 1;\n\ninit\nCREATE VIEW w AS SELECT 2;\n", + ) + .await; + + let benchmark = benchmark.expect("benchmark should parse"); + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Init) + .expect("init queries"), + &vec![ + "CREATE VIEW v AS SELECT 1;".to_string(), + "CREATE VIEW w AS SELECT 2;".to_string(), + ] + ); + } + + #[tokio::test] + async fn parser_accepts_multiple_cleanup_directives() { + let benchmark = + parse_benchmark("cleanup\nDROP TABLE t;\n\ncleanup\nDROP TABLE u;\n") + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Cleanup) + .expect("cleanup queries"), + &vec!["DROP TABLE t;".to_string(), "DROP TABLE u;".to_string(),] + ); + } + + #[tokio::test] + async fn parser_rejects_missing_query_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let missing_path = temp_dir.path().join("missing.sql"); + let benchmark_path = write_test_file( + &temp_dir, + "missing_query_file.benchmark", + &format!("run {}\n", missing_path.display()), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains(result, "Failed to read query file"); + } + + #[tokio::test] + async fn parser_rejects_template_with_invalid_parameter_assignment() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let template_path = + write_test_file(&temp_dir, "template.benchmark", "run\nselect 1\n"); + + let benchmark_path = write_test_file( + &temp_dir, + "template_driver.benchmark", + &format!("template {}\nINVALID\n", template_path.display()), + ); + + let ctx = SessionContext::new(); + let benchmark_path_string = benchmark_path.to_string_lossy().into_owned(); + let result = SqlBenchmark::new(&ctx, &benchmark_path_string, "/tmp").await; + + let error = result.expect_err("benchmark parsing should fail"); + let message = error.to_string(); + assert!( + message.contains("Expected a template parameter in the form of X=Y"), + "expected template parameter error, got {message:?}" + ); + } + + #[tokio::test] + async fn parser_rejects_metadata_and_result_directives_without_values() { + assert_parse_error("name\n", "name must be followed by a value").await; + assert_parse_error("group\n", "group must be followed by a value").await; + assert_parse_error("subgroup\n", "subgroup must be followed by a value").await; + assert_parse_error( + "expect_plan\n", + "expect_plan must be followed by a string to search in the physical plan", + ) + .await; + assert_parse_error("echo\n", "Echo requires an argument").await; + assert_parse_error( + "result\n", + "result must be followed by a path to a result file", + ) + .await; + assert_parse_error("include\n", "include requires a single argument").await; + assert_parse_error("template\n", "template requires a single template path") + .await; + } + + #[tokio::test] + async fn parser_rejects_include_and_template_with_too_many_arguments() { + assert_parse_error("include 
a b\n", "include requires a single argument").await; + assert_parse_error("template a b\n", "template requires a single template path") + .await; + } + + #[tokio::test] + async fn parser_rejects_missing_include_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let missing_path = temp_dir.path().join("missing_include.benchmark"); + let benchmark_path = write_test_file( + &temp_dir, + "missing_include_driver.benchmark", + &format!("include {}\n", missing_path.display()), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains(result, "No such file"); + } + + #[tokio::test] + async fn parser_rejects_missing_template_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let missing_path = temp_dir.path().join("missing_template.benchmark"); + let benchmark_path = write_test_file( + &temp_dir, + "missing_template_driver.benchmark", + &format!("template {}\n", missing_path.display()), + ); + + let result = parse_benchmark_file(&benchmark_path).await; + + assert_result_error_contains(result, "No such file"); + } + + #[tokio::test] + async fn parser_uses_metadata_values_as_replacements() { + let benchmark = parse_benchmark( + r#" +name Q01 +group tpch +subgroup sf1 + +run +SELECT '${BENCH_NAME}', '${BENCH_GROUP}', '${BENCH_SUBGROUP}' +"#, + ) + .await + .expect("benchmark should parse"); + + assert_eq!( + benchmark + .queries() + .get(&QueryDirective::Run) + .expect("run query"), + &vec!["SELECT 'Q01', 'tpch', 'sf1'".to_string()] + ); + } + + #[tokio::test] + async fn parser_accepts_replacement_in_result_file_path() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let result_path = write_test_file(&temp_dir, "result.csv", "value\n1\n"); + let template_path = write_test_file( + &temp_dir, + "result_path_template.benchmark", + "result ${RESULT_PATH}\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "result_path_driver.benchmark", + &format!( + "template {}\nRESULT_PATH={}\n", + template_path.display(), + result_path.display() + ), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + let query = benchmark + .result_queries() + .first() + .expect("result query should be parsed"); + assert_eq!(query.path, Some(result_path.to_string_lossy().into_owned())); + assert_eq!(query.column_count, 0); + assert!(query.expected_result.is_empty()); + } + + #[tokio::test] + async fn parser_rejects_missing_replacement_in_result_file_path() { + assert_parse_error("result ${MISSING_RESULT_PATH}\n", "Missing value for key") + .await; + } + + #[tokio::test] + async fn parser_accepts_missing_result_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let missing_path = temp_dir.path().join("missing_result.csv"); + let benchmark_path = write_test_file( + &temp_dir, + "missing_result_file.benchmark", + &format!("result {}\n", missing_path.display()), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + let query = benchmark + .result_queries() + .first() + .expect("result file should be parsed"); + assert_eq!( + query.path, + Some(missing_path.to_string_lossy().into_owned()) + ); + assert!(query.expected_result.is_empty()); + } + + #[tokio::test] + async fn parser_accepts_malformed_result_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let result_path = 
temp_dir.path().join("malformed_result.csv"); + fs::write(&result_path, [0xff]).expect("failed to write malformed result file"); + let benchmark_path = write_test_file( + &temp_dir, + "malformed_result_file.benchmark", + &format!("result {}\n", result_path.display()), + ); + + let benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + + let query = benchmark + .result_queries() + .first() + .expect("result file should be parsed"); + assert_eq!(query.path, Some(result_path.to_string_lossy().into_owned())); + assert!(query.expected_result.is_empty()); + } + + // Lifecycle tests cover initialization, assertions, and cleanup execution. + + #[tokio::test] + async fn initialize_executes_load_before_init_and_is_idempotent() { + let mut benchmark = parse_benchmark( + r#" +load +CREATE TABLE t AS SELECT 1 AS value; + +load +CREATE TABLE u AS SELECT value + 1 AS value FROM t; + +init +CREATE TABLE v AS SELECT value + 1 AS value FROM u; + +init +CREATE TABLE initialized AS SELECT value + 1 AS value FROM v; + +run +SELECT value FROM initialized; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark + .initialize(&ctx) + .await + .expect("initialize should succeed"); + benchmark + .initialize(&ctx) + .await + .expect("second initialize should be a no-op"); + + assert!(benchmark.is_loaded()); + + let rows = ctx + .sql("SELECT value FROM initialized") + .await + .expect("query should plan") + .collect() + .await + .expect("query should run"); + + assert_eq!(format_record_batches(&rows).unwrap(), vec![vec!["4"]]); + } + + #[tokio::test] + async fn initialize_rejects_benchmark_without_run_query() { + let mut benchmark = parse_benchmark( + r#" +load +CREATE TABLE t AS SELECT 1 AS value; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.initialize(&ctx).await, + "Invalid benchmark file: no \"run\" query specified", + ); + } + + #[tokio::test] + async fn initialize_propagates_load_query_failures() { + let mut benchmark = parse_benchmark( + r#" +load +CREATE TABLE t AS SELECT * FROM missing_load_table; + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.initialize(&ctx).await, + "missing_load_table", + ); + } + + #[tokio::test] + async fn initialize_propagates_init_query_failures() { + let mut benchmark = parse_benchmark( + r#" +init +CREATE TABLE t AS SELECT * FROM missing_init_table; + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.initialize(&ctx).await, + "missing_init_table", + ); + } + + #[tokio::test] + async fn cleanup_executes_cleanup_queries() { + let mut benchmark = parse_benchmark( + r#" +run +SELECT 1; + +cleanup +CREATE TABLE cleanup_marker_a AS SELECT 7 AS value; + +cleanup +CREATE TABLE cleanup_marker_b AS SELECT value + 1 AS value FROM cleanup_marker_a; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.cleanup(&ctx).await.expect("cleanup should run"); + + let rows = ctx + .sql("SELECT value FROM cleanup_marker_b") + .await + .expect("query should plan") + .collect() + .await + .expect("query should run"); + assert_eq!(format_record_batches(&rows).unwrap(), vec![vec!["8"]]); + } + + #[tokio::test] + async fn cleanup_propagates_query_failures() { + let mut 
benchmark = parse_benchmark( + r#" +run +SELECT 1; + +cleanup +SELECT * FROM missing_cleanup_table; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.cleanup(&ctx).await, + "missing_cleanup_table", + ); + } + + #[tokio::test] + async fn assert_executes_assert_queries_successfully() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT 1 AS value +---- +1 + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.assert(&ctx).await.expect("assert should pass"); + } + + #[tokio::test] + async fn assert_accepts_null_expected_for_empty_actual() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT '' AS value +---- +NULL + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.assert(&ctx).await.expect("assert should pass"); + } + + #[tokio::test] + async fn assert_accepts_empty_marker_for_empty_actual() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT '' AS value +---- +(empty) + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.assert(&ctx).await.expect("assert should pass"); + } + + #[tokio::test] + async fn assert_accepts_empty_marker_for_null_actual() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT CAST(NULL AS VARCHAR) AS value +---- +(empty) + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.assert(&ctx).await.expect("assert should pass"); + } + + #[tokio::test] + async fn assert_succeeds_with_zero_actual_and_expected_rows() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT 1 AS value WHERE false +---- + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.assert(&ctx).await.expect("assert should pass"); + } + + #[tokio::test] + async fn assert_propagates_query_failures() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT * FROM missing_assert_table +---- +1 + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.assert(&ctx).await, + "missing_assert_table", + ); + } + + #[tokio::test] + async fn assert_reports_row_count_mismatch() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT 1 AS value +---- +1 +2 + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.assert(&ctx).await, + "expected 2 rows but got 1", + ); + } + + #[tokio::test] + async fn assert_reports_column_count_mismatch() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT 1 AS a, 2 AS b +---- +1 + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.assert(&ctx).await, + "expected 1 columns but got 2", + ); + } + + #[tokio::test] + async fn assert_reports_value_mismatch() { + let mut benchmark = parse_benchmark( + r#" +assert I +SELECT 1 AS value +---- +2 + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.assert(&ctx).await, + "expected value \"2\" but got value \"1\"", + ); + } + + // Run tests cover result 
buffering and physical-plan expectations. + + #[tokio::test] + async fn run_saves_uppercase_select_results() { + let mut benchmark = parse_benchmark("run\nSELECT 1 AS value\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_eq!(formatted_last_results(&benchmark), vec![vec!["1"]]); + } + + #[tokio::test] + async fn run_saves_with_query_results() { + let mut benchmark = + parse_benchmark("run\nWITH t AS (SELECT 3 AS value) SELECT value FROM t\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_eq!(formatted_last_results(&benchmark), vec![vec!["3"]]); + } + + #[tokio::test] + async fn run_only_keeps_last_select_or_with_result() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let query_path = write_test_file( + &temp_dir, + "queries.sql", + "SELECT 1 AS value;\nSELECT 2 AS value;\nWITH t AS (SELECT 3 AS value) SELECT value FROM t;\n", + ); + let benchmark_path = write_test_file( + &temp_dir, + "run_file.benchmark", + &format!("run {}\n", query_path.display()), + ); + let mut benchmark = parse_benchmark_file(&benchmark_path) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_eq!(formatted_last_results(&benchmark), vec![vec!["3"]]); + } + + #[tokio::test] + async fn run_inline_multi_statement_only_keeps_last_select_or_with_result() { + let mut benchmark = parse_benchmark( + "run\nCREATE TABLE t AS SELECT 1 AS value;\nSELECT 2 AS value;\nWITH u AS (SELECT 3 AS value) SELECT value FROM u;\n", + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_eq!(formatted_last_results(&benchmark), vec![vec!["3"]]); + } + + #[tokio::test] + async fn run_does_not_save_results_for_non_select_statement() { + let mut benchmark = + parse_benchmark("run\nCREATE TABLE run_created AS SELECT 1 AS value;\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert!( + benchmark + .last_results + .as_ref() + .expect("last results should be set") + .is_empty() + ); + } + + #[tokio::test] + async fn run_propagates_query_failures_when_buffering_results() { + let mut benchmark = parse_benchmark("run\nSELECT * FROM missing_run_table\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.run(&ctx, true).await, + "missing_run_table", + ); + } + + #[tokio::test] + async fn run_propagates_query_failures_when_streaming_results() { + let mut benchmark = parse_benchmark("run\nSELECT * FROM missing_stream_table\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.run(&ctx, false).await, + "missing_stream_table", + ); + } + + #[tokio::test] + async fn run_rejects_missing_expect_plan_for_buffered_and_streaming_modes() { + let ctx = SessionContext::new(); + let benchmark_text = "expect_plan definitely_not_in_plan\nrun\nSELECT 1\n"; + + let mut buffered = parse_benchmark(benchmark_text) + .await + .expect("benchmark should parse"); + assert_result_error_contains( + buffered.run(&ctx, true).await, + "does not contain the 
expected string 'definitely_not_in_plan'", + ); + + let mut streaming = parse_benchmark(benchmark_text) + .await + .expect("benchmark should parse"); + assert_result_error_contains( + streaming.run(&ctx, false).await, + "does not contain the expected string 'definitely_not_in_plan'", + ); + } + + #[tokio::test] + async fn run_accepts_matching_expect_plan_for_buffered_and_streaming_modes() { + let ctx = SessionContext::new(); + let benchmark_text = "expect_plan PlaceholderRowExec\nrun\nSELECT 1\n"; + + let mut buffered = parse_benchmark(benchmark_text) + .await + .expect("benchmark should parse"); + buffered + .run(&ctx, true) + .await + .expect("buffered run should accept matching plan"); + assert_eq!(formatted_last_results(&buffered), vec![vec!["1"]]); + + let mut streaming = parse_benchmark(benchmark_text) + .await + .expect("benchmark should parse"); + streaming + .run(&ctx, false) + .await + .expect("streaming run should accept matching plan"); + } + + // Verification tests cover result_query and persisted-result comparison paths. + + #[tokio::test] + async fn verify_without_result_query_returns_ok() { + let mut benchmark = parse_benchmark("run\nSELECT 1 AS value\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.verify(&ctx).await.expect("verify should pass"); + } + + #[tokio::test] + async fn verify_errors_when_benchmark_has_not_run() { + let mut benchmark = parse_benchmark( + r#" +result_query I +SELECT 1 AS value +---- +1 + +run +SELECT 1; +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.verify(&ctx).await, + "No results available for verification. Run the benchmark first.", + ); + } + + #[tokio::test] + async fn verify_uses_last_results_for_result_file_entries() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let result_path = write_test_file(&temp_dir, "result.csv", "value\n1\n"); + let mut benchmark = parse_benchmark(&format!( + "result {}\n\nrun\nSELECT 1 AS value\n", + result_path.display() + )) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + benchmark.verify(&ctx).await.expect("verify should pass"); + } + + #[tokio::test] + async fn verify_uses_last_results_for_zero_row_result_file_entries() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let result_path = write_test_file(&temp_dir, "result.csv", "value\n"); + let mut benchmark = parse_benchmark(&format!( + "result {}\n\nrun\nSELECT 1 AS value WHERE false\n", + result_path.display() + )) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + benchmark.verify(&ctx).await.expect("verify should pass"); + } + + #[tokio::test] + async fn verify_rejects_missing_result_file() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let missing_path = temp_dir.path().join("missing_result.csv"); + let mut benchmark = parse_benchmark(&format!( + "result {}\n\nrun\nSELECT 1 AS value\n", + missing_path.display() + )) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_result_error_contains(benchmark.verify(&ctx).await, "missing_result.csv"); + } + + #[tokio::test] + async fn verify_rejects_malformed_result_file() 
{ + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let result_path = temp_dir.path().join("malformed_result.csv"); + fs::write(&result_path, [0xff]).expect("failed to write malformed result file"); + let mut benchmark = parse_benchmark(&format!( + "result {}\n\nrun\nSELECT 1 AS value\n", + result_path.display() + )) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_result_error_contains(benchmark.verify(&ctx).await, "CSV"); + } + + #[tokio::test] + async fn verify_executes_result_query_instead_of_last_results() { + let mut benchmark = parse_benchmark( + r#" +run +SELECT 100 AS value + +result_query I +SELECT 1 AS value +---- +1 +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + benchmark.verify(&ctx).await.expect("verify should pass"); + } + + #[tokio::test] + async fn verify_propagates_result_query_failures() { + let mut benchmark = parse_benchmark( + r#" +run +SELECT 1 AS value + +result_query I +SELECT * FROM missing_verify_table +---- +1 +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + assert_result_error_contains( + benchmark.verify(&ctx).await, + "missing_verify_table", + ); + } + + #[tokio::test] + async fn verify_reports_result_mismatch_context() { + let mut benchmark = parse_benchmark( + r#" +run +SELECT 1 AS value + +result_query I +SELECT 1 AS value +---- +2 +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.run(&ctx, true).await.expect("run should succeed"); + + let error = benchmark + .verify(&ctx) + .await + .expect_err("verify should fail"); + let message = error.to_string(); + assert!( + message.contains("row 1, column 1") + && message.contains("expected value \"2\"") + && message.contains("got value \"1\""), + "unexpected error: {message}" + ); + } + + // Persistence tests cover CSV writing and persist-time error paths. 
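+    // As checked below, persisted results are pipe-delimited CSV with a header
+    // row that is always written, even for zero-row results, e.g.:
+    //
+    //   a|b
+    //   1|one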
+ + #[tokio::test] + async fn persist_without_result_query_returns_ok() { + let mut benchmark = parse_benchmark("run\nSELECT 1 AS value\n") + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + benchmark.persist(&ctx).await.expect("persist should pass"); + } + + #[tokio::test] + async fn persist_rejects_result_query_without_file_path() { + let mut benchmark = parse_benchmark( + r#" +run +SELECT 1 AS value + +result_query I +SELECT 1 AS value +---- +1 +"#, + ) + .await + .expect("benchmark should parse"); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.persist(&ctx).await, + "Unable to persist results from query", + ); + } + + #[tokio::test] + async fn persist_rejects_run_without_saved_result_batches() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let output_path = temp_dir.path().join("persisted"); + let mut benchmark = + parse_benchmark("run\nCREATE TABLE persist_source AS SELECT 1 AS value;\n") + .await + .expect("benchmark should parse"); + benchmark.result_queries.push(BenchmarkQuery { + path: Some(output_path.to_string_lossy().into_owned()), + query: String::new(), + column_count: 1, + expected_result: vec![], + }); + let ctx = SessionContext::new(); + + assert_result_error_contains( + benchmark.persist(&ctx).await, + "Results should be loaded", + ); + } + + #[tokio::test] + async fn persist_writes_header_and_pipe_delimited_rows() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let output_path = temp_dir.path().join("persisted"); + let mut benchmark = parse_benchmark("run\nSELECT 1 AS a, 'one' AS b\n") + .await + .expect("benchmark should parse"); + benchmark.result_queries.push(BenchmarkQuery { + path: Some(output_path.to_string_lossy().into_owned()), + query: String::new(), + column_count: 2, + expected_result: vec![], + }); + let ctx = SessionContext::new(); + + benchmark.persist(&ctx).await.expect("persist should pass"); + + let contents = read_all_files_in_dir(&output_path); + assert!( + contents.contains("a|b\n") && contents.contains("1|one\n"), + "unexpected persisted contents: {contents:?}" + ); + } + + #[tokio::test] + async fn persist_writes_header_for_zero_row_select_results() { + let temp_dir = tempdir().expect("failed to create benchmark test directory"); + let output_path = temp_dir.path().join("persisted_empty"); + let mut benchmark = parse_benchmark("run\nSELECT 1 AS value WHERE false\n") + .await + .expect("benchmark should parse"); + benchmark.result_queries.push(BenchmarkQuery { + path: Some(output_path.to_string_lossy().into_owned()), + query: String::new(), + column_count: 1, + expected_result: vec![], + }); + let ctx = SessionContext::new(); + + benchmark.persist(&ctx).await.expect("persist should pass"); + + let contents = read_all_files_in_dir(&output_path); + assert!( + contents.contains("value\n"), + "unexpected persisted contents: {contents:?}" + ); + } + + // Path helper tests cover group derivation from benchmark file paths. 
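+    // By example (see the assertions below):
+    //   sql_benchmarks/tpch/benchmarks/q01.benchmark      -> "tpch"
+    //   /tmp/SQL_BENCHMARKS/Tpch/benchmarks/q01.benchmark -> "Tpch"
+    //       (the benchmark directory itself is matched case-insensitively)
+    //   outside/group/q01.benchmark                       -> "outside"
+    //       (fallback: first path component when outside the benchmark directory)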
+ + #[test] + fn parse_group_from_path_returns_group_under_benchmark_directory() { + let group = parse_group_from_path( + Path::new("sql_benchmarks/tpch/benchmarks/q01.benchmark"), + Path::new("sql_benchmarks"), + ); + + assert_eq!(group, "tpch"); + } + + #[test] + fn parse_group_from_path_matches_benchmark_directory_case_insensitively() { + let group = parse_group_from_path( + Path::new("/tmp/SQL_BENCHMARKS/Tpch/benchmarks/q01.benchmark"), + Path::new("sql_benchmarks"), + ); + + assert_eq!(group, "Tpch"); + } + + #[test] + fn parse_group_from_path_handles_relative_and_absolute_paths() { + let relative = parse_group_from_path( + Path::new("sql_benchmarks/h2o/q01.benchmark"), + Path::new("sql_benchmarks"), + ); + let absolute = parse_group_from_path( + Path::new("/tmp/sql_benchmarks/imdb/q01.benchmark"), + Path::new("sql_benchmarks"), + ); + + assert_eq!(relative, "h2o"); + assert_eq!(absolute, "imdb"); + } + + #[test] + fn parse_group_from_path_pins_fallback_for_paths_outside_benchmark_directory() { + let group = parse_group_from_path( + Path::new("outside/group/q01.benchmark"), + Path::new("sql_benchmarks"), + ); + + assert_eq!(group, "outside"); + } + + #[test] + fn path_ends_with_ignore_ascii_case_matches_component_suffixes() { + assert!(path_ends_with_ignore_ascii_case( + Path::new("/tmp/SQL_BENCHMARKS"), + Path::new("sql_benchmarks") + )); + assert!(!path_ends_with_ignore_ascii_case( + Path::new("/tmp/sql_benchmarks_extra"), + Path::new("sql_benchmarks") + )); + } +} diff --git a/datafusion/core/tests/schema_adapter/mod.rs b/benchmarks/src/tpcds/mod.rs similarity index 95% rename from datafusion/core/tests/schema_adapter/mod.rs rename to benchmarks/src/tpcds/mod.rs index 2f81a43f4736e..4829eb9fd348a 100644 --- a/datafusion/core/tests/schema_adapter/mod.rs +++ b/benchmarks/src/tpcds/mod.rs @@ -15,4 +15,5 @@ // specific language governing permissions and limitations // under the License. -mod schema_adapter_integration_tests; +mod run; +pub use run::RunOpt; diff --git a/benchmarks/src/tpcds/run.rs b/benchmarks/src/tpcds/run.rs new file mode 100644 index 0000000000000..f7ef6991515da --- /dev/null +++ b/benchmarks/src/tpcds/run.rs @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
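+
+//! Runner for the TPC-DS benchmark queries (q1 through q99): registers the
+//! TPC-DS tables as Parquet listing tables (optionally preloaded into a
+//! `MemTable`) and executes each query, recording per-iteration timings.
+//! See [`RunOpt`] for the command line options.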
+
+use std::fs;
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats};
+
+use arrow::record_batch::RecordBatch;
+use arrow::util::pretty::{self, pretty_format_batches};
+use datafusion::datasource::file_format::parquet::ParquetFormat;
+use datafusion::datasource::listing::{
+    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
+};
+use datafusion::datasource::{MemTable, TableProvider};
+use datafusion::error::Result;
+use datafusion::physical_plan::display::DisplayableExecutionPlan;
+use datafusion::physical_plan::{collect, displayable};
+use datafusion::prelude::*;
+use datafusion_common::instant::Instant;
+use datafusion_common::utils::get_available_parallelism;
+use datafusion_common::{DEFAULT_PARQUET_EXTENSION, plan_err};
+
+use clap::Args;
+use log::info;
+
+// hack to avoid `default_value is meaningless for bool` errors
+type BoolDefaultTrue = bool;
+pub const TPCDS_QUERY_START_ID: usize = 1;
+pub const TPCDS_QUERY_END_ID: usize = 99;
+
+pub const TPCDS_TABLES: &[&str] = &[
+    "call_center",
+    "customer_address",
+    "household_demographics",
+    "promotion",
+    "store_sales",
+    "web_page",
+    "catalog_page",
+    "customer_demographics",
+    "income_band",
+    "reason",
+    "store",
+    "web_returns",
+    "catalog_returns",
+    "customer",
+    "inventory",
+    "ship_mode",
+    "time_dim",
+    "web_sales",
+    "catalog_sales",
+    "date_dim",
+    "item",
+    "store_returns",
+    "warehouse",
+    "web_site",
+];
+
+/// Get the SQL statements from the specified query file
+pub fn get_query_sql(base_query_path: &str, query: usize) -> Result<Vec<String>> {
+    if query > 0 && query < 100 {
+        let filename = format!("{base_query_path}/{query}.sql");
+        let mut errors = vec![];
+        match fs::read_to_string(&filename) {
+            Ok(contents) => {
+                return Ok(contents
+                    .split(';')
+                    .map(|s| s.trim())
+                    .filter(|s| !s.is_empty())
+                    .map(|s| s.to_string())
+                    .collect());
+            }
+            Err(e) => errors.push(format!("{filename}: {e}")),
+        };
+
+        plan_err!("invalid query. Could not find query: {:?}", errors)
+    } else {
+        plan_err!("invalid query. Expected value between 1 and 99")
+    }
+}
+
+/// Run the TPC-DS benchmark.
+#[derive(Debug, Args, Clone)]
+#[command(verbatim_doc_comment)]
+pub struct RunOpt {
+    /// Query number. If not specified, runs all queries
+    #[arg(short, long)]
+    pub query: Option<usize>,
+
+    /// Common options
+    #[command(flatten)]
+    common: CommonOpt,
+
+    /// Path to data files
+    #[arg(required = true, short = 'p', long = "path")]
+    path: PathBuf,
+
+    /// Path to query files
+    #[arg(required = true, short = 'Q', long = "query_path")]
+    query_path: PathBuf,
+
+    /// Load the data into a MemTable before executing the query
+    #[arg(short = 'm', long = "mem-table")]
+    mem_table: bool,
+
+    /// Path to machine-readable output file
+    #[arg(short = 'o', long = "output")]
+    output_path: Option<PathBuf>,
+
+    /// Whether to disable collection of statistics (and cost-based optimizations).
+    #[arg(short = 'S', long = "disable-statistics")]
+    disable_statistics: bool,
+
+    /// If true then hash join is used, if false then sort merge join.
+    /// True by default.
+    #[arg(short = 'j', long = "prefer_hash_join", default_value = "true")]
+    prefer_hash_join: BoolDefaultTrue,
+
+    /// If true then Piecewise Merge Join can be used, if false then it will opt for Nested Loop Join.
+    /// False by default.
+    #[arg(
+        short = 'w',
+        long = "enable_piecewise_merge_join",
+        default_value = "false"
+    )]
+    enable_piecewise_merge_join: BoolDefaultTrue,
+
+    /// Mark the first column of each table as sorted in ascending order.
+    /// The tables should have been created with the `--sort` option for this to have any effect.
+    #[arg(short = 't', long = "sorted")]
+    sorted: bool,
+
+    /// How many bytes to buffer on the probe side of hash joins.
+    #[arg(long, default_value = "0")]
+    hash_join_buffering_capacity: usize,
+}
+
+impl RunOpt {
+    pub async fn run(self) -> Result<()> {
+        println!("Running benchmarks with the following options: {self:?}");
+        let query_range = match self.query {
+            Some(query_id) => query_id..=query_id,
+            None => TPCDS_QUERY_START_ID..=TPCDS_QUERY_END_ID,
+        };
+
+        let mut benchmark_run = BenchmarkRun::new();
+        let mut config = self
+            .common
+            .config()?
+            .with_collect_statistics(!self.disable_statistics);
+        config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join;
+        config.options_mut().optimizer.enable_piecewise_merge_join =
+            self.enable_piecewise_merge_join;
+        config.options_mut().execution.hash_join_buffering_capacity =
+            self.hash_join_buffering_capacity;
+        let rt = self.common.build_runtime()?;
+        let ctx = SessionContext::new_with_config_rt(config, rt);
+        // register tables
+        self.register_tables(&ctx).await?;
+
+        for query_id in query_range {
+            benchmark_run.start_new_case(&format!("Query {query_id}"));
+            let query_run = self.benchmark_query(query_id, &ctx).await;
+            match query_run {
+                Ok(query_results) => {
+                    for iter in query_results {
+                        benchmark_run.write_iter(iter.elapsed, iter.row_count);
+                    }
+                }
+                Err(e) => {
+                    benchmark_run.mark_failed();
+                    eprintln!("Query {query_id} failed: {e}");
+                }
+            }
+        }
+        benchmark_run.maybe_write_json(self.output_path.as_ref())?;
+        benchmark_run.maybe_print_failures();
+        Ok(())
+    }
+
+    async fn benchmark_query(
+        &self,
+        query_id: usize,
+        ctx: &SessionContext,
+    ) -> Result<Vec<QueryResult>> {
+        let mut millis = vec![];
+        // run benchmark
+        let mut query_results = vec![];
+
+        let sql = &get_query_sql(self.query_path.to_str().unwrap(), query_id)?;
+
+        if self.common.debug {
+            println!("=== SQL for query {query_id} ===\n{}\n", sql.join(";\n"));
+        }
+
+        for i in 0..self.iterations() {
+            let start = Instant::now();
+
+            // some queries contain multiple statements; only the result of the
+            // last statement executed is kept and reported
+            let mut result = vec![];
+
+            for query in sql {
+                result = self.execute_query(ctx, query).await?;
+            }
+
+            let elapsed = start.elapsed();
+            let ms = elapsed.as_secs_f64() * 1000.0;
+            millis.push(ms);
+            info!("output:\n\n{}\n\n", pretty_format_batches(&result)?);
+            let row_count = result.iter().map(|b| b.num_rows()).sum();
+            println!(
+                "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows"
+            );
+            query_results.push(QueryResult { elapsed, row_count });
+        }
+
+        let avg = millis.iter().sum::<f64>() / millis.len() as f64;
+        println!("Query {query_id} avg time: {avg:.2} ms");
+
+        // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended)
+        print_memory_stats();
+
+        Ok(query_results)
+    }
+
+    async fn register_tables(&self, ctx: &SessionContext) -> Result<()> {
+        for table in TPCDS_TABLES {
+            let table_provider = { self.get_table(ctx, table).await? };
+
+            if self.mem_table {
+                println!("Loading table '{table}' into memory");
+                let start = Instant::now();
+                let memtable =
+                    MemTable::load(table_provider, Some(self.partitions()), &ctx.state())
+                        .await?;
+                println!(
+                    "Loaded table '{}' into memory in {} ms",
+                    table,
+                    start.elapsed().as_millis()
+                );
+                ctx.register_table(*table, Arc::new(memtable))?;
+            } else {
+                ctx.register_table(*table, table_provider)?;
+            }
+        }
+        Ok(())
+    }
+
+    async fn execute_query(
+        &self,
+        ctx: &SessionContext,
+        sql: &str,
+    ) -> Result<Vec<RecordBatch>> {
+        let debug = self.common.debug;
+        let plan = ctx.sql(sql).await?;
+        let (state, plan) = plan.into_parts();
+
+        if debug {
+            println!("=== Logical plan ===\n{plan}\n");
+        }
+
+        let plan = state.optimize(&plan)?;
+        if debug {
+            println!("=== Optimized logical plan ===\n{plan}\n");
+        }
+        let physical_plan = state.create_physical_plan(&plan).await?;
+        if debug {
+            println!(
+                "=== Physical plan ===\n{}\n",
+                displayable(physical_plan.as_ref()).indent(true)
+            );
+        }
+        let result = collect(physical_plan.clone(), state.task_ctx()).await?;
+        if debug {
+            println!(
+                "=== Physical plan with metrics ===\n{}\n",
+                DisplayableExecutionPlan::with_metrics(physical_plan.as_ref())
+                    .indent(true)
+            );
+            if !result.is_empty() {
+                // do not call print_batches if there are no batches as the result is confusing
+                // and makes it look like there is a batch with no columns
+                pretty::print_batches(&result)?;
+            }
+        }
+        Ok(result)
+    }
+
+    async fn get_table(
+        &self,
+        ctx: &SessionContext,
+        table: &str,
+    ) -> Result<Arc<dyn TableProvider>> {
+        let path = self.path.to_str().unwrap();
+        let target_partitions = self.partitions();
+
+        // Obtain a snapshot of the SessionState
+        let state = ctx.state();
+        let path = format!("{path}/{table}.parquet");
+
+        // Check if the file exists
+        if !std::path::Path::new(&path).exists() {
+            eprintln!("Warning registering {table}: Table file does not exist: {path}");
+        }
+
+        let format = ParquetFormat::default()
+            .with_options(ctx.state().table_options().parquet.clone());
+
+        let table_path = ListingTableUrl::parse(path)?;
+        let options = ListingOptions::new(Arc::new(format))
+            .with_file_extension(DEFAULT_PARQUET_EXTENSION)
+            .with_target_partitions(target_partitions)
+            .with_collect_stat(state.config().collect_statistics());
+        let schema = options.infer_schema(&state, &table_path).await?;
+
+        if self.common.debug {
+            println!(
+                "Inferred schema from {table_path} for table '{table}':\n{schema:#?}\n"
+            );
+        }
+
+        let options = if self.sorted {
+            let key_column_name = schema.fields()[0].name();
+            options
+                .with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]])
+        } else {
+            options
+        };
+
+        let config = ListingTableConfig::new(table_path)
+            .with_listing_options(options)
+            .with_schema(schema);
+
+        Ok(Arc::new(ListingTable::try_new(config)?))
+    }
+
+    fn iterations(&self) -> usize {
+        self.common.iterations
+    }
+
+    fn partitions(&self) -> usize {
+        self.common
+            .partitions
+            .unwrap_or_else(get_available_parallelism)
+    }
+}
diff --git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs
deleted file mode 100644
index 5219e09cd3052..0000000000000
--- a/benchmarks/src/tpch/convert.rs
+++ /dev/null
@@ -1,162 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use datafusion::logical_expr::select_expr::SelectExpr;
-use datafusion_common::instant::Instant;
-use std::fs;
-use std::path::{Path, PathBuf};
-
-use datafusion::common::not_impl_err;
-
-use super::get_tbl_tpch_table_schema;
-use super::TPCH_TABLES;
-use datafusion::error::Result;
-use datafusion::prelude::*;
-use parquet::basic::Compression;
-use parquet::file::properties::WriterProperties;
-use structopt::StructOpt;
-
-/// Convert tpch .slt files to .parquet or .csv files
-#[derive(Debug, StructOpt)]
-pub struct ConvertOpt {
-    /// Path to csv files
-    #[structopt(parse(from_os_str), required = true, short = "i", long = "input")]
-    input_path: PathBuf,
-
-    /// Output path
-    #[structopt(parse(from_os_str), required = true, short = "o", long = "output")]
-    output_path: PathBuf,
-
-    /// Output file format: `csv` or `parquet`
-    #[structopt(short = "f", long = "format")]
-    file_format: String,
-
-    /// Compression to use when writing Parquet files
-    #[structopt(short = "c", long = "compression", default_value = "zstd")]
-    compression: String,
-
-    /// Number of partitions to produce
-    #[structopt(short = "n", long = "partitions", default_value = "1")]
-    partitions: usize,
-
-    /// Batch size when reading CSV or Parquet files
-    #[structopt(short = "s", long = "batch-size", default_value = "8192")]
-    batch_size: usize,
-
-    /// Sort each table by its first column in ascending order.
-    #[structopt(short = "t", long = "sort")]
-    sort: bool,
-}
-
-impl ConvertOpt {
-    pub async fn run(self) -> Result<()> {
-        let compression = self.compression()?;
-
-        let input_path = self.input_path.to_str().unwrap();
-        let output_path = self.output_path.to_str().unwrap();
-
-        let output_root_path = Path::new(output_path);
-        for table in TPCH_TABLES {
-            let start = Instant::now();
-            let schema = get_tbl_tpch_table_schema(table);
-            let key_column_name = schema.fields()[0].name();
-
-            let input_path = format!("{input_path}/{table}.tbl");
-            let options = CsvReadOptions::new()
-                .schema(&schema)
-                .has_header(false)
-                .delimiter(b'|')
-                .file_extension(".tbl");
-            let options = if self.sort {
-                // indicated that the file is already sorted by its first column to speed up the conversion
-                options
-                    .file_sort_order(vec![vec![col(key_column_name).sort(true, false)]])
-            } else {
-                options
-            };
-
-            let config = SessionConfig::new().with_batch_size(self.batch_size);
-            let ctx = SessionContext::new_with_config(config);
-
-            // build plan to read the TBL file
-            let mut csv = ctx.read_csv(&input_path, options).await?;
-
-            // Select all apart from the padding column
-            let selection = csv
-                .schema()
-                .iter()
-                .take(schema.fields.len() - 1)
-                .map(Expr::from)
-                .map(SelectExpr::from)
-                .collect::<Vec<_>>();
-
-            csv = csv.select(selection)?;
-            // optionally, repartition the file
-            let partitions = self.partitions;
-            if partitions > 1 {
-                csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))?
- } - let csv = if self.sort { - csv.sort_by(vec![col(key_column_name)])? - } else { - csv - }; - - // create the physical plan - let csv = csv.create_physical_plan().await?; - - let output_path = output_root_path.join(table); - let output_path = output_path.to_str().unwrap().to_owned(); - fs::create_dir_all(&output_path)?; - println!( - "Converting '{}' to {} files in directory '{}'", - &input_path, self.file_format, &output_path - ); - match self.file_format.as_str() { - "csv" => ctx.write_csv(csv, output_path).await?, - "parquet" => { - let props = WriterProperties::builder() - .set_compression(compression) - .build(); - ctx.write_parquet(csv, output_path, Some(props)).await? - } - other => { - return not_impl_err!("Invalid output format: {other}"); - } - } - println!("Conversion completed in {} ms", start.elapsed().as_millis()); - } - - Ok(()) - } - - /// return the compression method to use when writing parquet - fn compression(&self) -> Result { - Ok(match self.compression.as_str() { - "none" => Compression::UNCOMPRESSED, - "snappy" => Compression::SNAPPY, - "brotli" => Compression::BROTLI(Default::default()), - "gzip" => Compression::GZIP(Default::default()), - "lz4" => Compression::LZ4, - "lz0" => Compression::LZO, - "zstd" => Compression::ZSTD(Default::default()), - other => { - return not_impl_err!("Invalid compression format: {other}"); - } - }) - } -} diff --git a/benchmarks/src/tpch/mod.rs b/benchmarks/src/tpch/mod.rs index 233ea94a05c1a..08cedc0e5b4c3 100644 --- a/benchmarks/src/tpch/mod.rs +++ b/benchmarks/src/tpch/mod.rs @@ -27,15 +27,13 @@ use std::fs; mod run; pub use run::RunOpt; -mod convert; -pub use convert::ConvertOpt; - pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; pub const TPCH_QUERY_START_ID: usize = 1; pub const TPCH_QUERY_END_ID: usize = 22; +const TPCH_Q11_FRACTION_SENTINEL: &str = "0.0001 /* __TPCH_Q11_FRACTION__ */"; /// The `.tbl` file contains a trailing column pub fn get_tbl_tpch_table_schema(table: &str) -> Schema { @@ -142,6 +140,21 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { /// Get the SQL statements from the specified query file pub fn get_query_sql(query: usize) -> Result> { + get_query_sql_for_scale_factor(query, 1.0) +} + +/// Get the SQL statements from the specified query file using the provided scale factor for +/// TPC-H substitutions such as Q11 FRACTION. +pub fn get_query_sql_for_scale_factor( + query: usize, + scale_factor: f64, +) -> Result> { + if !(scale_factor.is_finite() && scale_factor > 0.0) { + return plan_err!( + "invalid scale factor. Expected a positive finite value, got {scale_factor}" + ); + } + if query > 0 && query < 23 { let possibilities = vec![ format!("queries/q{query}.sql"), @@ -151,6 +164,7 @@ pub fn get_query_sql(query: usize) -> Result> { for filename in possibilities { match fs::read_to_string(&filename) { Ok(contents) => { + let contents = customize_query_sql(query, contents, scale_factor)?; return Ok(contents .split(';') .map(|s| s.trim()) @@ -167,6 +181,27 @@ pub fn get_query_sql(query: usize) -> Result> { } } +fn customize_query_sql( + query: usize, + contents: String, + scale_factor: f64, +) -> Result { + if query != 11 { + return Ok(contents); + } + + if !contents.contains(TPCH_Q11_FRACTION_SENTINEL) { + return plan_err!( + "invalid query 11. 
Missing fraction marker {TPCH_Q11_FRACTION_SENTINEL}" + ); + } + + Ok(contents.replace( + TPCH_Q11_FRACTION_SENTINEL, + &format!("(0.0001 / {scale_factor})"), + )) +} + pub const QUERY_LIMIT: [Option<usize>; 22] = [ None, Some(100), @@ -191,3 +226,51 @@ pub const QUERY_LIMIT: [Option<usize>; 22] = [ Some(100), None, ]; + +#[cfg(test)] +mod tests { + use super::{get_query_sql, get_query_sql_for_scale_factor}; + use datafusion::error::Result; + + fn get_single_query(query: usize) -> Result<String> { + let mut queries = get_query_sql(query)?; + assert_eq!(queries.len(), 1); + Ok(queries.remove(0)) + } + + fn get_single_query_for_scale_factor( + query: usize, + scale_factor: f64, + ) -> Result<String> { + let mut queries = get_query_sql_for_scale_factor(query, scale_factor)?; + assert_eq!(queries.len(), 1); + Ok(queries.remove(0)) + } + + #[test] + fn q11_uses_scale_factor_substitution() -> Result<()> { + let sf1_sql = get_single_query(11)?; + assert!(sf1_sql.contains("(0.0001 / 1)")); + + let sf01_sql = get_single_query_for_scale_factor(11, 0.1)?; + assert!(sf01_sql.contains("(0.0001 / 0.1)")); + + let sf10_sql = get_single_query_for_scale_factor(11, 10.0)?; + assert!(sf10_sql.contains("(0.0001 / 10)")); + + let sf30_sql = get_single_query_for_scale_factor(11, 30.0)?; + assert!(sf30_sql.contains("(0.0001 / 30)")); + assert!(!sf10_sql.contains("__TPCH_Q11_FRACTION__")); + Ok(()) + } + + #[test] + fn interval_queries_use_interval_arithmetic() -> Result<()> { + assert!(get_single_query(5)?.contains("date '1994-01-01' + interval '1' year")); + assert!(get_single_query(6)?.contains("date '1994-01-01' + interval '1' year")); + assert!(get_single_query(10)?.contains("date '1993-10-01' + interval '3' month")); + assert!(get_single_query(12)?.contains("date '1994-01-01' + interval '1' year")); + assert!(get_single_query(14)?.contains("date '1995-09-01' + interval '1' month")); + Ok(()) + } +} diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index cc59b78030360..ec7aa8c554a28 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -15,20 +15,21 @@ // specific language governing permissions and limitations // under the License.
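Per the TPC-H spec, Q11's FRACTION parameter is 0.0001 divided by the scale factor, which is what the sentinel replacement above implements. A minimal usage sketch of the new entry point (assuming the benchmarks crate is importable as `datafusion_benchmarks`; that path is illustrative, not part of this diff):

    // Illustrative only: fetch TPC-H Q11 rewritten for scale factor 10.
    use datafusion_benchmarks::tpch::get_query_sql_for_scale_factor;

    fn q11_for_sf10() -> datafusion::error::Result<Vec<String>> {
        let statements = get_query_sql_for_scale_factor(11, 10.0)?;
        // The sentinel `0.0001 /* __TPCH_Q11_FRACTION__ */` in q11.sql is
        // replaced with `(0.0001 / 10)` before the SQL is split on ';'.
        assert!(statements.iter().any(|s| s.contains("(0.0001 / 10)")));
        Ok(statements)
    }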
-use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; use super::{ - get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_QUERY_END_ID, - TPCH_QUERY_START_ID, TPCH_TABLES, + TPCH_QUERY_END_ID, TPCH_QUERY_START_ID, TPCH_TABLES, get_query_sql_for_scale_factor, + get_tbl_tpch_table_schema, get_tpch_table_schema, }; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::common::exec_err; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -41,8 +42,8 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; +use clap::Args; use log::info; -use structopt::StructOpt; // hack to avoid `default_value is meaningless for bool` errors type BoolDefaultTrue = bool; @@ -56,46 +57,51 @@ type BoolDefaultTrue = bool; /// [1]: http://www.tpc.org/tpch/ /// [2]: https://github.com/databricks/tpch-dbgen.git /// [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number. If not specified, runs all queries - #[structopt(short, long)] + #[arg(short, long)] pub query: Option<usize>, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to data files - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, + /// TPC-H scale factor used for query substitutions such as Q11 FRACTION. + /// If omitted, the benchmark tries to infer it from paths like `.../tpch_sf10/...`. + #[arg(long)] + scale_factor: Option<f64>, + /// File format: `csv` or `parquet` - #[structopt(short = "f", long = "format", default_value = "csv")] + #[arg(short = 'f', long = "format", default_value = "csv")] file_format: String, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option<PathBuf>, /// Whether to disable collection of statistics (and cost based optimizations) or not. - #[structopt(short = "S", long = "disable-statistics")] + #[arg(short = 'S', long = "disable-statistics")] disable_statistics: bool, /// If true then hash join used, if false then sort merge join /// True by default. - #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + #[arg(short = 'j', long = "prefer_hash_join", default_value = "true")] prefer_hash_join: BoolDefaultTrue, /// If true then Piecewise Merge Join can be used, if false then it will opt for Nested Loop Join - /// True by default. - #[structopt( - short = "j", + /// False by default.
+ #[arg( + short = 'w', long = "enable_piecewise_merge_join", default_value = "false" )] @@ -103,8 +109,12 @@ pub struct RunOpt { /// Mark the first column of each table as sorted in ascending order. /// The tables should have been created with the `--sort` option for this to have any effect. - #[structopt(short = "t", long = "sorted")] + #[arg(short = 't', long = "sorted")] sorted: bool, + + /// How many bytes to buffer on the probe side of hash joins. + #[arg(long, default_value = "0")] + hash_join_buffering_capacity: usize, } impl RunOpt { @@ -123,14 +133,17 @@ impl RunOpt { config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; config.options_mut().optimizer.enable_piecewise_merge_join = self.enable_piecewise_merge_join; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + config.options_mut().execution.hash_join_buffering_capacity = + self.hash_join_buffering_capacity; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); // register tables self.register_tables(&ctx).await?; + let scale_factor = self.scale_factor()?; for query_id in query_range { benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(query_id, &ctx).await; + let query_run = self.benchmark_query(query_id, scale_factor, &ctx).await; match query_run { Ok(query_results) => { for iter in query_results { @@ -151,13 +164,14 @@ impl RunOpt { async fn benchmark_query( &self, query_id: usize, + scale_factor: f64, ctx: &SessionContext, ) -> Result<Vec<QueryResult>> { let mut millis = vec![]; // run benchmark let mut query_results = vec![]; - let sql = &get_query_sql(query_id)?; + let sql = &get_query_sql_for_scale_factor(query_id, scale_factor)?; for i in 0..self.iterations() { let start = Instant::now(); @@ -340,6 +354,82 @@ impl RunOpt { .partitions .unwrap_or_else(get_available_parallelism) } + + fn scale_factor(&self) -> Result<f64> { + resolve_scale_factor(self.scale_factor, &self.path) + } +} + +fn resolve_scale_factor(scale_factor: Option<f64>, path: &Path) -> Result<f64> { + let scale_factor = scale_factor + .or_else(|| infer_scale_factor_from_path(path)) + .unwrap_or(1.0); + + if scale_factor.is_finite() && scale_factor > 0.0 { + Ok(scale_factor) + } else { + exec_err!( + "Invalid TPC-H scale factor {scale_factor}. Expected a positive finite value" + ) + } +} + +fn infer_scale_factor_from_path(path: &Path) -> Option<f64> { + path.iter().find_map(|component| { + component + .to_str()? + .strip_prefix("tpch_sf")? + .parse::<f64>() + .ok() + }) +} + +#[cfg(test)] +mod scale_factor_tests { + use std::path::Path; + + use super::{infer_scale_factor_from_path, resolve_scale_factor}; + use datafusion::error::Result; + + #[test] + fn uses_explicit_scale_factor_when_provided() -> Result<()> { + let scale_factor = + resolve_scale_factor(Some(30.0), Path::new("benchmarks/data/tpch_sf10"))?; + assert_eq!(scale_factor, 30.0); + Ok(()) + } + + #[test] + fn infers_scale_factor_from_standard_tpch_path() -> Result<()> { + let scale_factor = + resolve_scale_factor(None, Path::new("benchmarks/data/tpch_sf10"))?; + assert_eq!(scale_factor, 10.0); + assert_eq!( + infer_scale_factor_from_path(Path::new("benchmarks/data/tpch_sf0.1")), + Some(0.1) + ); + Ok(()) + } + + #[test] + fn defaults_to_sf1_when_path_has_no_scale_factor() -> Result<()> { + let scale_factor = resolve_scale_factor(None, Path::new("benchmarks/data"))?; + assert_eq!(scale_factor, 1.0); + Ok(()) + } + + #[test] + fn rejects_invalid_scale_factors() { + assert!(resolve_scale_factor(Some(0.0), Path::new("benchmarks/data")).is_err()); + assert!(resolve_scale_factor(Some(-1.0), Path::new("benchmarks/data")).is_err()); + assert!( + resolve_scale_factor(Some(f64::NAN), Path::new("benchmarks/data")).is_err() + ); + assert!( + resolve_scale_factor(Some(f64::INFINITY), Path::new("benchmarks/data")) + .is_err() + ); + } } #[cfg(test)] @@ -380,11 +470,13 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), common, path: PathBuf::from(path.to_string()), + scale_factor: Some(1.0), file_format: "tbl".to_string(), mem_table: false, output_path: None, @@ -392,9 +484,10 @@ mod tests { prefer_hash_join: true, enable_piecewise_merge_join: false, sorted: false, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; - let queries = get_query_sql(query)?; + let queries = crate::tpch::get_query_sql(query)?; for query in queries { let plan = ctx.sql(&query).await?; let plan = plan.into_optimized_plan()?; @@ -418,11 +511,13 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), common, path: PathBuf::from(path.to_string()), + scale_factor: Some(1.0), file_format: "tbl".to_string(), mem_table: false, output_path: None, @@ -392,9 +484,10 @@ mod tests { prefer_hash_join: true, enable_piecewise_merge_join: false, sorted: false, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; - let queries = get_query_sql(query)?; + let queries = crate::tpch::get_query_sql(query)?; for query in queries { let plan = ctx.sql(&query).await?; let plan = plan.create_physical_plan().await?; diff --git a/benchmarks/src/util/latency_object_store.rs b/benchmarks/src/util/latency_object_store.rs new file mode 100644 index 0000000000000..9ef8d1b78b751 --- /dev/null +++ b/benchmarks/src/util/latency_object_store.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An ObjectStore wrapper that adds simulated S3-like latency to get and list operations. +//! +//! Cycles through a fixed latency distribution inspired by real S3 performance: +//! - P50: ~30ms +//! - P75-P90: ~100-120ms +//! - P99: ~150-200ms + +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use futures::StreamExt; +use futures::stream::BoxStream; +use object_store::path::Path; +use object_store::{ + CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, +}; + +/// GET latency distribution, inspired by S3 latencies. +/// Deterministic but shuffled to avoid artificial patterns. +/// 20 values: 11x P50 (~25-35ms), 5x P75-P90 (~70-110ms), 2x P95 (~120-150ms), 2x P99 (~180-200ms) +/// Sorted: 25,25,28,28,30,30,30,30,32,32,35, 70,85,100,100,110, 130,150, 180,200 +/// P50≈32ms, P90≈110ms, P99≈200ms +const GET_LATENCIES_MS: &[u64] = &[ + 30, 100, 25, 85, 32, 200, 28, 130, 35, 70, 30, 150, 30, 110, 28, 180, 32, 25, 100, 30, +]; + +/// LIST latency distribution, generally higher than GET. +/// 20 values: 11x P50 (~40-70ms), 5x P75-P90 (~120-180ms), 2x P95 (~200-250ms), 2x P99 (~300-400ms) +/// Sorted: 40,40,50,50,55,55,60,60,65,65,70, 120,140,160,160,180, 210,250, 300,400 +/// P50≈65ms, P90≈180ms, P99≈400ms +const LIST_LATENCIES_MS: &[u64] = &[ + 55, 160, 40, 140, 65, 400, 50, 210, 70, 120, 60, 250, 55, 180, 50, 300, 65, 40, 160, + 60, +]; + +/// An ObjectStore wrapper that injects simulated latency on get and list calls. 
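The fixed tables above trade randomness for reproducibility: every benchmark run observes the same latency sequence, so run-to-run timing differences come from the engine, not the simulated store. A reduced, self-contained sketch of the cycling pattern (latency values shortened and hypothetical):

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    // Hypothetical, shortened latency table; the real ones are above.
    const LATENCIES_MS: &[u64] = &[30, 100, 25, 85];

    fn next_latency(counter: &AtomicUsize) -> Duration {
        // Relaxed ordering suffices: each call only needs a fresh index,
        // not synchronization with other memory operations. The modulo
        // keeps the index in bounds even after the counter wraps.
        let idx = counter.fetch_add(1, Ordering::Relaxed) % LATENCIES_MS.len();
        Duration::from_millis(LATENCIES_MS[idx])
    }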
+#[derive(Debug)] +pub struct LatencyObjectStore<T: ObjectStore> { + inner: T, + get_counter: AtomicUsize, + list_counter: AtomicUsize, +} + +impl<T: ObjectStore> LatencyObjectStore<T> { + pub fn new(inner: T) -> Self { + Self { + inner, + get_counter: AtomicUsize::new(0), + list_counter: AtomicUsize::new(0), + } + } + + fn next_get_latency(&self) -> Duration { + let idx = + self.get_counter.fetch_add(1, Ordering::Relaxed) % GET_LATENCIES_MS.len(); + Duration::from_millis(GET_LATENCIES_MS[idx]) + } + + fn next_list_latency(&self) -> Duration { + let idx = + self.list_counter.fetch_add(1, Ordering::Relaxed) % LIST_LATENCIES_MS.len(); + Duration::from_millis(LIST_LATENCIES_MS[idx]) + } +} + +impl<T: ObjectStore> fmt::Display for LatencyObjectStore<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "LatencyObjectStore({})", self.inner) + } +} + +#[async_trait] +impl<T: ObjectStore> ObjectStore for LatencyObjectStore<T> { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result<PutResult> { + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> Result<Box<dyn MultipartUpload>> { + self.inner.put_multipart_opts(location, opts).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> { + tokio::time::sleep(self.next_get_latency()).await; + self.inner.get_opts(location, options).await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[std::ops::Range<u64>], + ) -> Result<Vec<Bytes>> { + tokio::time::sleep(self.next_get_latency()).await; + self.inner.get_ranges(location, ranges).await + } + + fn delete_stream( + &self, + locations: BoxStream<'static, Result<Path>>, + ) -> BoxStream<'static, Result<Path>> { + self.inner.delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result<ObjectMeta>> { + let latency = self.next_list_latency(); + let stream = self.inner.list(prefix); + futures::stream::once(async move { + tokio::time::sleep(latency).await; + futures::stream::empty() + }) + .flatten() + .chain(stream) + .boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> { + tokio::time::sleep(self.next_list_latency()).await; + self.inner.list_with_delimiter(prefix).await + } + + async fn copy_opts( + &self, + from: &Path, + to: &Path, + options: CopyOptions, + ) -> Result<()> { + self.inner.copy_opts(from, to, options).await + } +} diff --git a/benchmarks/src/util/memory.rs b/benchmarks/src/util/memory.rs index 944239df31cfd..11b96ef227756 100644 --- a/benchmarks/src/util/memory.rs +++ b/benchmarks/src/util/memory.rs @@ -19,7 +19,7 @@ pub fn print_memory_stats() { #[cfg(all(feature = "mimalloc", feature = "mimalloc_extended"))] { - use datafusion::execution::memory_pool::human_readable_size; + use datafusion_common::human_readable_size; let mut peak_rss = 0; let mut peak_commit = 0; let mut page_faults = 0; diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs index ab4579a566f66..6dc11c0f425bd 100644 --- a/benchmarks/src/util/mod.rs +++ b/benchmarks/src/util/mod.rs @@ -16,6 +16,7 @@ // under the License. //!
Shared benchmark utilities +pub mod latency_object_store; mod memory; mod options; mod run; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index 6627a287dfcd4..a3e6d2a4c5538 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -17,50 +17,59 @@ use std::{num::NonZeroUsize, sync::Arc}; +use clap::Args; use datafusion::{ execution::{ disk_manager::DiskManagerBuilder, memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, - runtime_env::RuntimeEnvBuilder, + object_store::ObjectStoreUrl, + runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, }, prelude::SessionConfig, }; use datafusion_common::{DataFusionError, Result}; -use structopt::StructOpt; +use object_store::local::LocalFileSystem; + +use super::latency_object_store::LatencyObjectStore; // Common benchmark options (don't use doc comments otherwise this doc // shows up in help files) -#[derive(Debug, StructOpt, Clone)] +#[derive(Debug, Args, Clone)] pub struct CommonOpt { /// Number of iterations of each test run - #[structopt(short = "i", long = "iterations", default_value = "3")] + #[arg(short = 'i', long = "iterations", default_value = "3", env)] pub iterations: usize, /// Number of partitions to process in parallel. Defaults to number of available cores. - #[structopt(short = "n", long = "partitions")] + #[arg(short = 'n', long = "partitions", env)] pub partitions: Option<usize>, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size")] + #[arg(short = 's', long = "batch-size", env)] pub batch_size: Option<usize>, /// The memory pool type to use, should be one of "fair" or "greedy" - #[structopt(long = "mem-pool-type", default_value = "fair")] + #[arg(long = "mem-pool-type", default_value = "fair", env)] pub mem_pool_type: String, /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query /// if there's any, otherwise run with no memory limit. - #[structopt(long = "memory-limit", parse(try_from_str = parse_memory_limit))] + #[arg(long = "memory-limit", value_parser = parse_capacity_limit, env)] pub memory_limit: Option<usize>, /// The amount of memory to reserve for sort spill operations. DataFusion's default value will be used /// if not specified. - #[structopt(long = "sort-spill-reservation-bytes", parse(try_from_str = parse_memory_limit))] + #[arg(long = "sort-spill-reservation-bytes", value_parser = parse_capacity_limit, env)] pub sort_spill_reservation_bytes: Option<usize>, /// Activate debug mode to see more details - #[structopt(short, long)] + #[arg(short, long, env)] pub debug: bool, + + /// Simulate object store latency to mimic remote storage (e.g. S3). + /// Adds deterministic S3-like latency (get: 25-200ms, list: 40-400ms) to object store get and list operations.
+ #[arg(long = "simulate-latency", env)] + pub simulate_latency: bool, } impl CommonOpt { @@ -91,7 +100,15 @@ impl CommonOpt { pub fn runtime_env_builder(&self) -> Result { let mut rt_builder = RuntimeEnvBuilder::new(); const NUM_TRACKED_CONSUMERS: usize = 5; - if let Some(memory_limit) = self.memory_limit { + // Use CLI --memory-limit if provided, otherwise fall back to + // DATAFUSION_RUNTIME_MEMORY_LIMIT env var + let memory_limit = self.memory_limit.or_else(|| { + std::env::var("DATAFUSION_RUNTIME_MEMORY_LIMIT") + .ok() + .and_then(|val| parse_capacity_limit(&val).ok()) + }); + + if let Some(memory_limit) = memory_limit { let pool: Arc = match self.mem_pool_type.as_str() { "fair" => Arc::new(TrackConsumersPool::new( FairSpillPool::new(memory_limit), @@ -105,7 +122,7 @@ impl CommonOpt { return Err(DataFusionError::Configuration(format!( "Invalid memory pool type: {}", self.mem_pool_type - ))) + ))); } }; rt_builder = rt_builder @@ -114,22 +131,44 @@ impl CommonOpt { } Ok(rt_builder) } + + /// Build the runtime environment, optionally wrapping the local filesystem + /// with a throttled object store to simulate remote storage latency. + pub fn build_runtime(&self) -> Result> { + let rt = self.runtime_env_builder()?.build_arc()?; + if self.simulate_latency { + let store: Arc = + Arc::new(LatencyObjectStore::new(LocalFileSystem::new())); + let url = ObjectStoreUrl::parse("file:///")?; + rt.register_object_store(url.as_ref(), store); + println!( + "Simulating S3-like object store latency (get: 25-200ms, list: 40-400ms)" + ); + } + Ok(rt) + } } -/// Parse memory limit from string to number of bytes -/// e.g. '1.5G', '100M' -> 1572864 -fn parse_memory_limit(limit: &str) -> Result { +/// Parse capacity limit from string to number of bytes by allowing units: K, M and G. +/// Supports formats like '1.5G' -> 1610612736, '100M' -> 104857600 +fn parse_capacity_limit(limit: &str) -> Result { + if limit.trim().is_empty() { + return Err("Capacity limit cannot be empty".to_string()); + } let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number .parse() - .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?; + .map_err(|_| format!("Failed to parse number from capacity limit '{limit}'"))?; + if number.is_sign_negative() || number.is_infinite() { + return Err("Limit value should be positive finite number".to_string()); + } match unit { "K" => Ok((number * 1024.0) as usize), "M" => Ok((number * 1024.0 * 1024.0) as usize), "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), _ => Err(format!( - "Unsupported unit '{unit}' in memory limit '{limit}'" + "Unsupported unit '{unit}' in capacity limit '{limit}'. 
Unit must be one of: 'K', 'M', 'G'" )), } } @@ -139,16 +178,59 @@ mod tests { use super::*; #[test] - fn test_parse_memory_limit_all() { + fn test_runtime_env_builder_reads_env_var() { + // Set the env var and verify runtime_env_builder picks it up + // when no CLI --memory-limit is provided + let opt = CommonOpt { + iterations: 3, + partitions: None, + batch_size: None, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, + debug: false, + simulate_latency: false, + }; + + // With env var set, builder should succeed and have a memory pool + // SAFETY: This test is single-threaded and the env var is restored after use + unsafe { + std::env::set_var("DATAFUSION_RUNTIME_MEMORY_LIMIT", "2G"); + } + let builder = opt.runtime_env_builder().unwrap(); + let runtime = builder.build().unwrap(); + unsafe { + std::env::remove_var("DATAFUSION_RUNTIME_MEMORY_LIMIT"); + } + // A 2G memory pool should be present — verify it reports the correct limit + match runtime.memory_pool.memory_limit() { + datafusion::execution::memory_pool::MemoryLimit::Finite(limit) => { + assert_eq!(limit, 2 * 1024 * 1024 * 1024); + } + _ => panic!("Expected Finite memory limit"), + } + } + + #[test] + fn test_parse_capacity_limit_all() { // Test valid inputs - assert_eq!(parse_memory_limit("100K").unwrap(), 102400); - assert_eq!(parse_memory_limit("1.5M").unwrap(), 1572864); - assert_eq!(parse_memory_limit("2G").unwrap(), 2147483648); + assert_eq!(parse_capacity_limit("100K").unwrap(), 102400); + assert_eq!(parse_capacity_limit("1.5M").unwrap(), 1572864); + assert_eq!(parse_capacity_limit("2G").unwrap(), 2147483648); // Test invalid unit - assert!(parse_memory_limit("500X").is_err()); + assert!(parse_capacity_limit("500X").is_err()); // Test invalid number - assert!(parse_memory_limit("abcM").is_err()); + assert!(parse_capacity_limit("abcM").is_err()); + + // Test negative number + assert!(parse_capacity_limit("-1M").is_err()); + + // Test infinite number + assert!(parse_capacity_limit("infM").is_err()); + + // Test negative infinite number + assert!(parse_capacity_limit("-infM").is_err()); } } diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs index 764ea648ff725..df17674e62961 100644 --- a/benchmarks/src/util/run.rs +++ b/benchmarks/src/util/run.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{error::Result, DATAFUSION_VERSION}; +use datafusion::{DATAFUSION_VERSION, error::Result}; use datafusion_common::utils::get_available_parallelism; use serde::{Serialize, Serializer}; use serde_json::Value; diff --git a/ci/scripts/changed_crates.sh b/ci/scripts/changed_crates.sh new file mode 100755 index 0000000000000..6d014a9492632 --- /dev/null +++ b/ci/scripts/changed_crates.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Helper script for the breaking-changes-detector workflow. +# +# Subcommands: +# changed-crates +# Print space-separated list of crate names whose files changed vs base_ref. +# Only published workspace members (those without `publish = false`) are +# considered. +# +# semver-check +# Run cargo-semver-checks for the given packages against base_ref. +# Output and exit code are passed through unchanged; the caller is +# responsible for capturing/formatting them. + +set -euo pipefail + +# ── changed-crates ────────────────────────────────────────────────── +cmd_changed_crates() { + local base_ref="${1:?Usage: changed_crates.sh changed-crates }" + + # 1. Files changed between the PR and the base branch. + local changed_files + changed_files=$(git diff --name-only "${base_ref}...HEAD") + + # 2. Every publishable workspace member, one per line as + # " ". `publish = false` in Cargo.toml shows + # up as `"publish": []` in cargo metadata, so filtering on that + # excludes internal crates without a manual exclusion list. + local crates + crates=$(cargo metadata --no-deps --format-version 1 | jq -r ' + (.workspace_root + "/") as $root + | .packages[] + | select(.publish != []) + | "\(.name) \(.manifest_path | ltrimstr($root) | rtrimstr("/Cargo.toml"))" + ') + + # 3. Keep crates whose directory contains a changed file. + while read -r name dir; do + if grep -q "^${dir}/" <<<"$changed_files"; then + echo "$name" + fi + done <<<"$crates" | xargs +} + +# ── semver-check ──────────────────────────────────────────────────── +cmd_semver_check() { + local base_ref="${1:?Usage: changed_crates.sh semver-check }" + shift + + local args=() + for pkg in "$@"; do + args+=(--package "$pkg") + done + + cargo semver-checks --baseline-rev "$base_ref" "${args[@]}" +} + +# ── main ──────────────────────────────────────────────────────────── +cmd="${1:?Usage: changed_crates.sh [args...]}" +shift + +case "$cmd" in + changed-crates) cmd_changed_crates "$@" ;; + semver-check) cmd_semver_check "$@" ;; + *) echo "Unknown command: $cmd" >&2; exit 1 ;; +esac diff --git a/ci/scripts/check_asf_yaml_status_checks.py b/ci/scripts/check_asf_yaml_status_checks.py new file mode 100644 index 0000000000000..135654159051c --- /dev/null +++ b/ci/scripts/check_asf_yaml_status_checks.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Validate that every entry in .asf.yaml required_status_checks +matches an actual GitHub Actions job name, and that the workflow +is not filtered by paths/paths-ignore (which would prevent the +check from running on some PRs, blocking merges). 
+ +A typo or stale entry in required_status_checks will block all +merges for the project, so this check catches that early. +""" + +import glob +import os +import sys + +import yaml + + +def get_required_checks(asf_yaml_path): + """Extract all required_status_checks contexts from .asf.yaml.""" + with open(asf_yaml_path) as f: + config = yaml.safe_load(f) + + checks = {} # context -> list of branches requiring it + branches = config.get("github", {}).get("protected_branches", {}) + for branch, settings in branches.items(): + contexts = ( + settings.get("required_status_checks", {}).get("contexts", []) + ) + for ctx in contexts: + checks.setdefault(ctx, []).append(branch) + + return checks + + +def get_workflow_jobs(workflows_dir): + """Collect all jobs with their metadata from GitHub Actions workflow files. + + Returns a dict mapping job identifier (name or key) to a list of + (workflow_file, has_path_filters) tuples. + """ + jobs = {} # identifier -> [(workflow_file, has_path_filters)] + for workflow_file in sorted(glob.glob(os.path.join(workflows_dir, "*.yml"))): + with open(workflow_file) as f: + workflow = yaml.safe_load(f) + + if not workflow or "jobs" not in workflow: + continue + + # Check if pull_request trigger has path filters + on = workflow.get(True, workflow.get("on", {})) # yaml parses `on:` as True + pr_trigger = on.get("pull_request", {}) if isinstance(on, dict) else {} + has_path_filters = bool( + isinstance(pr_trigger, dict) + and (pr_trigger.get("paths") or pr_trigger.get("paths-ignore")) + ) + + basename = os.path.basename(workflow_file) + for job_key, job_config in workflow.get("jobs", {}).items(): + if not isinstance(job_config, dict): + continue + job_name = job_config.get("name", job_key) + info = (basename, has_path_filters) + jobs.setdefault(job_name, []).append(info) + if job_key != job_name: + jobs.setdefault(job_key, []).append(info) + + return jobs + + +def main(): + repo_root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + asf_yaml = os.path.join(repo_root, ".asf.yaml") + workflows_dir = os.path.join(repo_root, ".github", "workflows") + + required_checks = get_required_checks(asf_yaml) + if not required_checks: + print("No required_status_checks found in .asf.yaml — nothing to validate.") + return + + jobs = get_workflow_jobs(workflows_dir) + errors = [] + + for ctx in sorted(required_checks): + branches = ", ".join(sorted(required_checks[ctx])) + if ctx not in jobs: + errors.append( + f' - "{ctx}" (branch: {branches}): ' + f"not found in any GitHub Actions workflow" + ) + continue + + # Check if ALL workflows providing this job have path filters + # (if at least one doesn't, the check will still run) + filtered_workflows = [ + wf for wf, has_filter in jobs[ctx] if has_filter + ] + unfiltered_workflows = [ + wf for wf, has_filter in jobs[ctx] if not has_filter + ] + if filtered_workflows and not unfiltered_workflows: + wf_list = ", ".join(filtered_workflows) + errors.append( + f' - "{ctx}" (branch: {branches}): ' + f"workflow {wf_list} uses paths/paths-ignore filters on " + f"pull_request, so this check won't run for some PRs " + f"and will block merging" + ) + + if errors: + print("ERROR: Problems found with required_status_checks in .asf.yaml:\n") + print("\n".join(errors)) + print() + print("Available job names across all workflows:") + for name in sorted(jobs): + print(f" - {name}") + sys.exit(1) + + print( + f"OK: All {len(required_checks)} required_status_checks " + "match existing GitHub Actions jobs." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/ci/scripts/check_examples_docs.sh b/ci/scripts/check_examples_docs.sh new file mode 100755 index 0000000000000..62308b323b535 --- /dev/null +++ b/ci/scripts/check_examples_docs.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Generates documentation for DataFusion examples using the Rust-based +# documentation generator and verifies that the committed README.md +# is up to date. +# +# The README is generated from documentation comments in: +# datafusion-examples/examples//main.rs +# +# This script is intended to be run in CI to ensure that example +# documentation stays in sync with the code. +# +# To update the README locally, run this script and replace README.md +# with the generated output. + +set -euo pipefail + +ROOT_DIR="$(git rev-parse --show-toplevel)" + +# Load centralized tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" + +EXAMPLES_DIR="$ROOT_DIR/datafusion-examples" +README="$EXAMPLES_DIR/README.md" +README_NEW="$EXAMPLES_DIR/README-NEW.md" + +echo "▶ Generating examples README (Rust generator)…" +cargo run --quiet \ + --manifest-path "$EXAMPLES_DIR/Cargo.toml" \ + --bin examples-docs \ + > "$README_NEW" + +echo "▶ Formatting generated README with prettier ${PRETTIER_VERSION}…" +npx "prettier@${PRETTIER_VERSION}" \ + --parser markdown \ + --write "$README_NEW" + +echo "▶ Comparing generated README with committed version…" + +if ! diff -u "$README" "$README_NEW" > /tmp/examples-readme.diff; then + echo "" + echo "❌ Examples README is out of date." + echo "" + echo "The examples documentation is generated automatically from:" + echo " - datafusion-examples/examples//main.rs" + echo "" + echo "To update the README locally, run:" + echo "" + echo " cargo run --bin examples-docs \\" + echo " | npx prettier@${PRETTIER_VERSION} --parser markdown --write \\" + echo " > datafusion-examples/README.md" + echo "" + echo "Diff:" + echo "------------------------------------------------------------" + cat /tmp/examples-readme.diff + echo "------------------------------------------------------------" + exit 1 +fi + +echo "✅ Examples README is up-to-date." diff --git a/ci/scripts/doc_prettier_check.sh b/ci/scripts/doc_prettier_check.sh new file mode 100755 index 0000000000000..95332eb65aaf2 --- /dev/null +++ b/ci/scripts/doc_prettier_check.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euo pipefail + +ROOT_DIR="$(git rev-parse --show-toplevel)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +# Load shared utilities and tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" +source "${ROOT_DIR}/ci/scripts/utils/git.sh" + +PRETTIER_TARGETS=( + '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' + '!datafusion/CHANGELOG.md' + README.md + CONTRIBUTING.md +) + +MODE="check" +ALLOW_DIRTY=0 + +usage() { + cat >&2 </dev/null 2>&1; then + echo "npx is required to run the prettier check. Install Node.js (e.g., brew install node) and re-run." >&2 + exit 1 +fi + +PRETTIER_MODE=(--check) +if [[ "$MODE" == "write" ]]; then + PRETTIER_MODE=(--write) +fi + +# Ignore subproject CHANGELOG.md because it is machine generated +npx "prettier@${PRETTIER_VERSION}" "${PRETTIER_MODE[@]}" "${PRETTIER_TARGETS[@]}" diff --git a/ci/scripts/license_header.sh b/ci/scripts/license_header.sh index 5345728f9cdf0..7ab8c9637598b 100755 --- a/ci/scripts/license_header.sh +++ b/ci/scripts/license_header.sh @@ -17,6 +17,62 @@ # specific language governing permissions and limitations # under the License. -# Check Apache license header -set -ex -hawkeye check --config licenserc.toml +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +source "${SCRIPT_DIR}/utils/git.sh" + +MODE="check" +ALLOW_DIRTY=0 +HAWKEYE_CONFIG="licenserc.toml" + +usage() { + cat >&2 <&2 <&2 <&2 <&2 <&2 + return 1 + fi +} diff --git a/ci/scripts/utils/tool_versions.sh b/ci/scripts/utils/tool_versions.sh new file mode 100644 index 0000000000000..237b18b62ef40 --- /dev/null +++ b/ci/scripts/utils/tool_versions.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file defines centralized tool versions used by CI and development scripts. +# It is intended to be sourced by other scripts and should not be executed directly. 
+ +PRETTIER_VERSION="2.7.1" +LYCHEE_VERSION="0.23.0" diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index f3069b492352d..414b8c6444869 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,10 +37,10 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.8.7" -aws-credential-types = "1.2.7" +aws-config = "1.8.16" +aws-credential-types = "1.2.13" chrono = { workspace = true } -clap = { version = "4.5.50", features = ["cargo", "derive"] } +clap = { version = "4.5.60", features = ["cargo", "derive"] } datafusion = { workspace = true, features = [ "avro", "compression", @@ -65,14 +65,24 @@ object_store = { workspace = true, features = ["aws", "gcp", "http"] } parking_lot = { workspace = true } parquet = { workspace = true, default-features = false } regex = { workspace = true } -rustyline = "17.0" +rustyline = "18.0" tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] } url = { workspace = true } +[lints] +workspace = true + [dev-dependencies] ctor = { workspace = true } insta = { workspace = true } insta-cmd = "0.6.0" rstest = { workspace = true } -testcontainers = { workspace = true } testcontainers-modules = { workspace = true, features = ["minio"] } +# Makes sure `test_display_pg_json` behaves in a consistent way regardless of +# feature unification with dependencies +serde_json = { workspace = true, features = ["preserve_order"] } + +# Required because we pull serde_json with a feature to get consistent pg display, +# but it's not directly used. +[package.metadata.cargo-machete] +ignored = ["serde_json"] diff --git a/datafusion-cli/examples/cli-session-context.rs b/datafusion-cli/examples/cli-session-context.rs index bd2dbb736781f..6095072163870 100644 --- a/datafusion-cli/examples/cli-session-context.rs +++ b/datafusion-cli/examples/cli-session-context.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use datafusion::{ dataframe::DataFrame, error::DataFusionError, - execution::{context::SessionState, TaskContext}, + execution::{TaskContext, context::SessionState}, logical_expr::{LogicalPlan, LogicalPlanBuilder}, prelude::SessionContext, }; diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 20d62eabc3901..185dfb6b08006 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License.
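The `preserve_order` note in the Cargo.toml hunk above matters because serde_json's default map is a BTreeMap that sorts keys alphabetically; enabling the feature switches to an insertion-ordered map, so JSON output stays stable no matter which dependency enables the feature first. A small sketch of the observable difference (values illustrative):

    use serde_json::json;

    fn main() {
        let v = json!({"plan": "Projection", "output": 1});
        // With `preserve_order`: {"plan":"Projection","output":1}
        // Without it, keys are sorted: {"output":1,"plan":"Projection"}
        println!("{v}");
    }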
-use std::any::Any; use std::sync::{Arc, Weak}; -use crate::object_storage::{get_object_store, AwsOptions, GcpOptions}; +use crate::object_storage::{AwsOptions, GcpOptions, get_object_store}; use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}; use datafusion::common::plan_datafusion_err; -use datafusion::datasource::listing::ListingTableUrl; use datafusion::datasource::TableProvider; +use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::Result; use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; @@ -50,10 +49,6 @@ impl DynamicObjectStoreCatalog { } impl CatalogProviderList for DynamicObjectStoreCatalog { - fn as_any(&self) -> &dyn Any { - self - } - fn register_catalog( &self, name: String, @@ -91,10 +86,6 @@ impl DynamicObjectStoreCatalogProvider { } impl CatalogProvider for DynamicObjectStoreCatalogProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn schema_names(&self) -> Vec { self.inner.schema_names() } @@ -134,10 +125,6 @@ impl DynamicObjectStoreSchemaProvider { #[async_trait] impl SchemaProvider for DynamicObjectStoreSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn table_names(&self) -> Vec { self.inner.table_names() } @@ -152,10 +139,10 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { async fn table(&self, name: &str) -> Result>> { let inner_table = self.inner.table(name).await; - if inner_table.is_ok() { - if let Some(inner_table) = inner_table? { - return Ok(Some(inner_table)); - } + if inner_table.is_ok() + && let Some(inner_table) = inner_table? + { + return Ok(Some(inner_table)); } // if the inner schema provider didn't have a table by @@ -219,12 +206,12 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { } pub fn substitute_tilde(cur: String) -> String { - if let Some(usr_dir_path) = home_dir() { - if let Some(usr_dir) = usr_dir_path.to_str() { - if cur.starts_with('~') && !usr_dir.is_empty() { - return cur.replacen('~', usr_dir, 1); - } - } + if let Some(usr_dir_path) = home_dir() + && let Some(usr_dir) = usr_dir_path.to_str() + && cur.starts_with('~') + && !usr_dir.is_empty() + { + return cur.replacen('~', usr_dir, 1); } cur } @@ -359,10 +346,12 @@ mod tests { } else { "/home/user" }; - env::set_var( - if cfg!(windows) { "USERPROFILE" } else { "HOME" }, - test_home_path, - ); + unsafe { + env::set_var( + if cfg!(windows) { "USERPROFILE" } else { "HOME" }, + test_home_path, + ); + } let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"; let expected = PathBuf::from(test_home_path) .join("Code") @@ -376,12 +365,16 @@ mod tests { .to_string(); let actual = substitute_tilde(input.to_string()); assert_eq!(actual, expected); - match original_home { - Some(home_path) => env::set_var( - if cfg!(windows) { "USERPROFILE" } else { "HOME" }, - home_path.to_str().unwrap(), - ), - None => env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" }), + unsafe { + match original_home { + Some(home_path) => env::set_var( + if cfg!(windows) { "USERPROFILE" } else { "HOME" }, + home_path.to_str().unwrap(), + ), + None => { + env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" }) + } + } } } } diff --git a/datafusion-cli/src/cli_context.rs b/datafusion-cli/src/cli_context.rs index 516929ebacf19..a6320f03fe4de 100644 --- a/datafusion-cli/src/cli_context.rs +++ b/datafusion-cli/src/cli_context.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion::{ dataframe::DataFrame, 
error::DataFusionError, - execution::{context::SessionState, TaskContext}, + execution::{TaskContext, context::SessionState}, logical_expr::LogicalPlan, prelude::SessionContext, }; diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index 3fbfe5680cfcd..8aaa8025d1c3a 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -19,7 +19,7 @@ use crate::cli_context::CliSessionContext; use crate::exec::{exec_and_print, exec_from_lines}; -use crate::functions::{display_all_functions, Function}; +use crate::functions::{Function, display_all_functions}; use crate::print_format::PrintFormat; use crate::print_options::PrintOptions; use clap::ValueEnum; diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index d079a88a6440e..09347d6d7dc2c 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -35,19 +35,19 @@ use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::execution_plan::EmissionType; use datafusion::physical_plan::spill::get_record_batch_memory_size; -use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties}; +use datafusion::physical_plan::{ExecutionPlanProperties, execute_stream}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser; use datafusion::sql::sqlparser::dialect::dialect_from_str; use futures::StreamExt; use log::warn; use object_store::Error::Generic; -use rustyline::error::ReadlineError; use rustyline::Editor; +use rustyline::error::ReadlineError; use std::collections::HashMap; use std::fs::File; -use std::io::prelude::*; use std::io::BufReader; +use std::io::prelude::*; use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options @@ -153,7 +153,7 @@ pub async fn exec_from_repl( } } else { eprintln!( - "'\\{}' is not a valid command", + "'\\{}' is not a valid command, you can use '\\?' to see all commands", &line[1..] ); } @@ -168,7 +168,10 @@ pub async fn exec_from_repl( } } } else { - eprintln!("'\\{}' is not a valid command", &line[1..]); + eprintln!( + "'\\{}' is not a valid command, you can use '\\?' to see all commands", + &line[1..] + ); } } Ok(line) => { @@ -193,6 +196,7 @@ pub async fn exec_from_repl( } Err(ReadlineError::Interrupted) => { println!("^C"); + rl.helper().unwrap().reset_hint(); continue; } Err(ReadlineError::Eof) => { @@ -266,7 +270,7 @@ impl StatementExecutor { let options = task_ctx.session_config().options(); // Track memory usage for the query result if it's bounded - let mut reservation = + let reservation = MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool()); if physical_plan.boundedness().is_unbounded() { @@ -297,7 +301,7 @@ impl StatementExecutor { let curr_num_rows = batch.num_rows(); // Stop collecting results if the number of rows exceeds the limit // results batch should include the last batch that exceeds the limit - if row_count < max_rows + curr_num_rows { + if row_count < max_rows.saturating_add(curr_num_rows) { // Try to grow the reservation to accommodate the batch in memory reservation.try_grow(get_record_batch_memory_size(&batch))?; results.push(batch); @@ -334,7 +338,9 @@ impl StatementExecutor { if matches!(err.as_ref(), Generic { store, source: _ } if "S3".eq_ignore_ascii_case(store)) && self.statement_for_retry.is_some() => { - warn!("S3 region is incorrect, auto-detecting the correct region (this may be slow). 
Consider updating your region configuration."); + warn!( + "S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration." + ); let plan = create_plan(ctx, self.statement_for_retry.take().unwrap(), true) .await?; @@ -516,6 +522,7 @@ mod tests { use datafusion::common::plan_err; use datafusion::prelude::SessionContext; + use datafusion_common::assert_contains; use url::Url; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { @@ -699,8 +706,7 @@ mod tests { #[tokio::test] async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; - let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; + let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -710,15 +716,16 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("os error 2")); + assert_contains!(err.to_string(), "os error 2"); // for service_account_key - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'" + ); let err = create_external_table_test(location, &sql) .await - .unwrap_err() - .to_string(); - assert!(err.contains("No RSA key found in pem file"), "{err}"); + .unwrap_err(); + assert_contains!(err.to_string(), "Error reading pem file: no items found"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET @@ -726,7 +733,7 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("os error 2")); + assert_contains!(err.to_string(), "os error 2"); Ok(()) } @@ -748,8 +755,9 @@ mod tests { let location = "path/to/file.cvs"; // Test with format options - let sql = - format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')" + ); create_external_table_test(location, &sql).await.unwrap(); Ok(()) diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index d23b12469e385..26f007cdd3193 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -17,19 +17,24 @@ //! 
Functions that are query-able and searchable via the `\h` command

+use datafusion_common::instant::Instant;
 use std::fmt;
 use std::fs::File;
 use std::str::FromStr;
 use std::sync::Arc;

-use arrow::array::{Int64Array, StringArray, TimestampMillisecondArray, UInt64Array};
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
+use arrow::array::{
+    DurationMillisecondArray, GenericListArray, Int64Array, StringArray, StructArray,
+    TimestampMillisecondArray, UInt64Array,
+};
+use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
+use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit};
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use datafusion::catalog::{Session, TableFunctionImpl};
-use datafusion::common::{plan_err, Column};
-use datafusion::datasource::memory::MemorySourceConfig;
+use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl};
+use datafusion::common::{Column, plan_err};
 use datafusion::datasource::TableProvider;
+use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::error::Result;
 use datafusion::execution::cache::cache_manager::CacheManager;
 use datafusion::logical_expr::Expr;
@@ -224,11 +229,7 @@ struct ParquetMetadataTable {

 #[async_trait]
 impl TableProvider for ParquetMetadataTable {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn schema(&self) -> arrow::datatypes::SchemaRef {
+    fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }

@@ -321,7 +322,8 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option<String> {
-    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
+    fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
+        let exprs = args.exprs();
         let filename = match exprs.first() {
             Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet')
             Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")
@@ -421,7 +423,7 @@ impl TableFunctionImpl for ParquetMetadataFunc {
             compression_arr.push(format!("{:?}", column.compression()));
             // need to collect into Vec to format
             let encodings: Vec<_> = column.encodings().collect();
-            encodings_arr.push(format!("{:?}", encodings));
+            encodings_arr.push(format!("{encodings:?}"));
             index_page_offset_arr.push(column.index_page_offset());
             dictionary_page_offset_arr.push(column.dictionary_page_offset());
             data_page_offset_arr.push(column.data_page_offset());
@@ -473,11 +475,7 @@ struct MetadataCacheTable {

 #[async_trait]
 impl TableProvider for MetadataCacheTable {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn schema(&self) -> arrow::datatypes::SchemaRef {
+    fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }

@@ -512,7 +510,8 @@ impl MetadataCacheFunc {
 }

 impl TableFunctionImpl for MetadataCacheFunc {
-    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
+    fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
+        let exprs = args.exprs();
         if !exprs.is_empty() {
             return plan_err!("metadata_cache should have no arguments");
         }
@@ -581,3 +580,292 @@ impl TableFunctionImpl for MetadataCacheFunc {
         Ok(Arc::new(metadata_cache))
     }
 }
+
+/// STATISTICS_CACHE table function
+#[derive(Debug)]
+struct StatisticsCacheTable {
+    schema: SchemaRef,
+    batch: RecordBatch,
+}
+
+#[async_trait]
+impl TableProvider for StatisticsCacheTable {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> datafusion::logical_expr::TableType {
+        datafusion::logical_expr::TableType::Base
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(MemorySourceConfig::try_new_exec(
+            &[vec![self.batch.clone()]],
+            TableProvider::schema(self),
+            projection.cloned(),
+        )?)
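+        // MemorySourceConfig::try_new_exec exposes the single pre-built
+        // RecordBatch as a one-partition, in-memory ExecutionPlan; the
+        // planner-supplied projection is applied by the source itself, while
+        // `_filters` and `_limit` are deliberately ignored for this tiny table.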
+ } +} + +#[derive(Debug)] +pub struct StatisticsCacheFunc { + cache_manager: Arc, +} + +impl StatisticsCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for StatisticsCacheFunc { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.exprs(); + if !exprs.is_empty() { + return plan_err!("statistics_cache should have no arguments"); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + Field::new("num_rows", DataType::Utf8, false), + Field::new("num_columns", DataType::UInt64, false), + Field::new("table_size_bytes", DataType::Utf8, false), + Field::new("statistics_size_bytes", DataType::UInt64, false), + ])); + + // construct record batch from metadata + let mut path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut e_tag_arr = vec![]; + let mut version_arr = vec![]; + let mut num_rows_arr = vec![]; + let mut num_columns_arr = vec![]; + let mut table_size_bytes_arr = vec![]; + let mut statistics_size_bytes_arr = vec![]; + + if let Some(file_statistics_cache) = self.cache_manager.get_file_statistic_cache() + { + for (path, entry) in file_statistics_cache.list_entries() { + path_arr.push(path.to_string()); + file_modified_arr + .push(Some(entry.object_meta.last_modified.timestamp_millis())); + file_size_bytes_arr.push(entry.object_meta.size); + e_tag_arr.push(entry.object_meta.e_tag); + version_arr.push(entry.object_meta.version); + num_rows_arr.push(entry.num_rows.to_string()); + num_columns_arr.push(entry.num_columns as u64); + table_size_bytes_arr.push(entry.table_size_bytes.to_string()); + statistics_size_bytes_arr.push(entry.statistics_size_bytes as u64); + } + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(e_tag_arr)), + Arc::new(StringArray::from(version_arr)), + Arc::new(StringArray::from(num_rows_arr)), + Arc::new(UInt64Array::from(num_columns_arr)), + Arc::new(StringArray::from(table_size_bytes_arr)), + Arc::new(UInt64Array::from(statistics_size_bytes_arr)), + ], + )?; + + let statistics_cache = StatisticsCacheTable { schema, batch }; + Ok(Arc::new(statistics_cache)) + } +} + +/// Implementation of the `list_files_cache` table function in datafusion-cli. +/// +/// This function returns the cached results of running a LIST command on a +/// particular object store path for a table. The object metadata is returned as +/// a List of Structs, with one Struct for each object. DataFusion uses these +/// cached results to plan queries against external tables. 
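For orientation, here is a hedged sketch of how this table function is wired up and queried; it mirrors the `register_udtf` call added to `main.rs` later in this patch (the helper name `inspect_list_files_cache` is illustrative, not part of the patch):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_cli::functions::ListFilesCacheFunc;

// Register the UDTF against the session's cache manager, then inspect the
// cache through plain SQL; column names follow the schema documented below.
async fn inspect_list_files_cache(ctx: &SessionContext) -> Result<()> {
    ctx.register_udtf(
        "list_files_cache",
        Arc::new(ListFilesCacheFunc::new(
            ctx.task_ctx().runtime_env().cache_manager.clone(),
        )),
    );
    ctx.sql("SELECT path, metadata_size_bytes, expires_in FROM list_files_cache()")
        .await?
        .show()
        .await
}
```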
+/// +/// # Schema +/// ```sql +/// > describe select * from list_files_cache(); +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// | column_name | data_type | is_nullable | +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// | table | Utf8 | NO | +/// | path | Utf8 | NO | +/// | metadata_size_bytes | UInt64 | NO | +/// | expires_in | Duration(ms) | YES | +/// | metadata_list | List(Struct("file_path": non-null Utf8, "file_modified": non-null Timestamp(ms), "file_size_bytes": non-null UInt64, "e_tag": Utf8, "version": Utf8), field: 'metadata') | YES | +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// ``` +#[derive(Debug)] +struct ListFilesCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ListFilesCacheTable { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) + } +} + +#[derive(Debug)] +pub struct ListFilesCacheFunc { + cache_manager: Arc, +} + +impl ListFilesCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for ListFilesCacheFunc { + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + let exprs = args.exprs(); + if !exprs.is_empty() { + return plan_err!("list_files_cache should have no arguments"); + } + + let nested_fields = Fields::from(vec![ + Field::new("file_path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + ]); + + let metadata_field = + Field::new("metadata", DataType::Struct(nested_fields.clone()), true); + + let schema = Arc::new(Schema::new(vec![ + Field::new("table", DataType::Utf8, true), + Field::new("path", DataType::Utf8, false), + Field::new("metadata_size_bytes", DataType::UInt64, false), + // expires field in ListFilesEntry has type Instant when set, from which we cannot get "the number of seconds", hence using Duration instead of Timestamp as data type. 
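+            // For example, with `expires: Option<Instant>`, the remaining TTL
+            // is computed as `t.duration_since(now)` in the loop below and
+            // stored as elapsed milliseconds.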
+ Field::new( + "expires_in", + DataType::Duration(TimeUnit::Millisecond), + true, + ), + Field::new( + "metadata_list", + DataType::List(Arc::new(metadata_field.clone())), + true, + ), + ])); + + let mut table_arr = vec![]; + let mut path_arr = vec![]; + let mut metadata_size_bytes_arr = vec![]; + let mut expires_arr = vec![]; + + let mut file_path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut etag_arr = vec![]; + let mut version_arr = vec![]; + let mut offsets: Vec = vec![0]; + + if let Some(list_files_cache) = self.cache_manager.get_list_files_cache() { + let now = Instant::now(); + let mut current_offset: i32 = 0; + + for (path, entry) in list_files_cache.list_entries() { + table_arr.push(path.table.map(|t| t.to_string())); + path_arr.push(path.path.to_string()); + metadata_size_bytes_arr.push(entry.size_bytes as u64); + // calculates time left before entry expires + expires_arr.push( + entry + .expires + .map(|t| t.duration_since(now).as_millis() as i64), + ); + + for meta in entry.metas.files.iter() { + file_path_arr.push(meta.location.to_string()); + file_modified_arr.push(meta.last_modified.timestamp_millis()); + file_size_bytes_arr.push(meta.size); + etag_arr.push(meta.e_tag.clone()); + version_arr.push(meta.version.clone()); + } + current_offset += entry.metas.files.len() as i32; + offsets.push(current_offset); + } + } + + let struct_arr = StructArray::new( + nested_fields, + vec![ + Arc::new(StringArray::from(file_path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(etag_arr)), + Arc::new(StringArray::from(version_arr)), + ], + None, + ); + + let offsets_buffer: OffsetBuffer = + OffsetBuffer::new(ScalarBuffer::from(Buffer::from_vec(offsets))); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(table_arr)), + Arc::new(StringArray::from(path_arr)), + Arc::new(UInt64Array::from(metadata_size_bytes_arr)), + Arc::new(DurationMillisecondArray::from(expires_arr)), + Arc::new(GenericListArray::new( + Arc::new(metadata_field), + offsets_buffer, + Arc::new(struct_arr), + None, + )), + ], + )?; + + let list_files_cache = ListFilesCacheTable { schema, batch }; + Ok(Arc::new(list_files_cache)) + } +} diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index 219637b3460e6..67e203cf7987b 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -19,8 +19,9 @@ //! and auto-completion for file name during creating external table. use std::borrow::Cow; +use std::cell::Cell; -use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; +use crate::highlighter::{Color, NoSyntaxHighlighter, SyntaxHighlighter}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; @@ -33,10 +34,17 @@ use rustyline::hint::Hinter; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{Context, Helper, Result}; +/// Default suggestion shown when the input line is empty. +const DEFAULT_HINT_SUGGESTION: &str = " \\? for help, \\q to quit"; + pub struct CliHelper { completer: FilenameCompleter, dialect: Dialect, highlighter: Box, + /// Tracks whether to show the default hint. Set to `false` once the user + /// types anything, so the hint doesn't reappear after deleting back to + /// an empty line. Reset to `true` when the line is submitted. 
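The `Cell<bool>` on the field that follows is needed because rustyline's `Hinter::hint` only receives `&self`; a standalone sketch of that interior-mutability pattern (hypothetical `HintState` type, not part of the patch):

```rust
use std::cell::Cell;

struct HintState {
    show: Cell<bool>,
}

impl HintState {
    // Takes `&self`, as trait methods like `Hinter::hint` require, yet the
    // flag can still be flipped: `Cell` moves values in and out instead of
    // handing out a `&mut` reference.
    fn on_user_typed(&self) {
        self.show.set(false);
    }
}

fn main() {
    let state = HintState { show: Cell::new(true) };
    state.on_user_typed();
    assert!(!state.show.get());
}
```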
+    show_hint: Cell<bool>,
 }

 impl CliHelper {
@@ -50,6 +58,7 @@ impl CliHelper {
             completer: FilenameCompleter::new(),
             dialect: *dialect,
             highlighter,
+            show_hint: Cell::new(true),
         }
     }

@@ -59,6 +68,11 @@ impl CliHelper {
         }
     }

+    /// Re-enable the default hint for the next prompt.
+    pub fn reset_hint(&self) {
+        self.show_hint.set(true);
+    }
+
     fn validate_input(&self, input: &str) -> Result<ValidationResult> {
         if let Some(sql) = input.strip_suffix(';') {
             let dialect = match dialect_from_str(self.dialect) {
@@ -67,7 +81,7 @@
                     return Ok(ValidationResult::Invalid(Some(format!(
                         " 🤔 Invalid dialect: {}",
                         self.dialect
-                    ))))
+                    ))));
                 }
             };
             let lines = split_from_semicolon(sql);
@@ -110,10 +124,22 @@ impl Highlighter for CliHelper {
     fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
         self.highlighter.highlight_char(line, pos, kind)
     }
+
+    fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> {
+        Color::gray(hint).into()
+    }
 }

 impl Hinter for CliHelper {
     type Hint = String;
+
+    fn hint(&self, line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
+        if !line.is_empty() {
+            self.show_hint.set(false);
+        }
+        (self.show_hint.get() && line.trim().is_empty())
+            .then(|| DEFAULT_HINT_SUGGESTION.to_owned())
+    }
 }

 /// returns true if the current position is after the open quote for
@@ -121,12 +147,9 @@
 fn is_open_quote_for_location(line: &str, pos: usize) -> bool {
     let mut sql = line[..pos].to_string();
     sql.push('\'');
-    if let Ok(stmts) = DFParser::parse_sql(&sql) {
-        if let Some(Statement::CreateExternalTable(_)) = stmts.back() {
-            return true;
-        }
-    }
-    false
+    DFParser::parse_sql(&sql).is_ok_and(|stmts| {
+        matches!(stmts.back(), Some(Statement::CreateExternalTable(_)))
+    })
 }

 impl Completer for CliHelper {
@@ -149,7 +172,9 @@
 impl Validator for CliHelper {
     fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result<ValidationResult> {
         let input = ctx.input().trim_end();
-        self.validate_input(input)
+        let result = self.validate_input(input);
+        self.reset_hint();
+        result
     }
 }
 diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs
index f4e57a2e3593a..adcb135bb401f 100644
--- a/datafusion-cli/src/highlighter.rs
+++ b/datafusion-cli/src/highlighter.rs
@@ -23,7 +23,7 @@ use std::{
 };

 use datafusion::sql::sqlparser::{
-    dialect::{dialect_from_str, Dialect, GenericDialect},
+    dialect::{Dialect, GenericDialect, dialect_from_str},
     keywords::Keyword,
     tokenizer::{Token, Tokenizer},
 };
@@ -38,7 +38,8 @@ pub struct SyntaxHighlighter {

 impl SyntaxHighlighter {
     pub fn new(dialect: &config::Dialect) -> Self {
-        let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {}));
+        let dialect =
+            dialect_from_str(dialect).unwrap_or_else(|| Box::new(GenericDialect {}));
         Self { dialect }
     }
 }
@@ -80,22 +81,26 @@ impl Highlighter for SyntaxHighlighter {
 }

 /// Convenient utility to return strings with [ANSI color](https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124).
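The helpers on `Color` below all follow the standard ANSI SGR scheme: `\x1b[<n>m` starts a style and `\x1b[0m` resets it, where codes 90, 91, and 92 select bright black (gray), bright red, and bright green. A one-line illustration:

```rust
fn main() {
    // 92 = bright green keyword, 90 = gray hint text, 0 resets the style.
    println!("\x1b[92mSELECT\x1b[0m \x1b[90m-- gray hint text\x1b[0m");
}
```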
-struct Color {} +pub(crate) struct Color {} impl Color { - fn green(s: impl Display) -> String { + pub(crate) fn green(s: impl Display) -> String { format!("\x1b[92m{s}\x1b[0m") } - fn red(s: impl Display) -> String { + pub(crate) fn red(s: impl Display) -> String { format!("\x1b[91m{s}\x1b[0m") } + + pub(crate) fn gray(s: impl Display) -> String { + format!("\x1b[90m{s}\x1b[0m") + } } #[cfg(test)] mod tests { - use super::config::Dialect; use super::SyntaxHighlighter; + use super::config::Dialect; use rustyline::highlight::Highlighter; #[test] diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 09fa8ef15af84..6bfe1160ecdd6 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -31,16 +31,17 @@ use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::logical_expr::ExplainFormat; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; -use datafusion_cli::functions::{MetadataCacheFunc, ParquetMetadataFunc}; +use datafusion_cli::functions::{ + ListFilesCacheFunc, MetadataCacheFunc, ParquetMetadataFunc, StatisticsCacheFunc, +}; use datafusion_cli::object_storage::instrumented::{ InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, }; use datafusion_cli::{ - exec, + DATAFUSION_CLI_VERSION, exec, pool_type::PoolType, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, - DATAFUSION_CLI_VERSION, }; use clap::Parser; @@ -244,6 +245,21 @@ async fn main_inner() -> Result<()> { )), ); + // register `statistics_cache` table function to get the contents of the file statistics cache + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.register_udtf( + "list_files_cache", + Arc::new(ListFilesCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + let mut print_options = PrintOptions { format: args.format, quiet: args.quiet, @@ -422,9 +438,20 @@ pub fn extract_disk_limit(size: &str) -> Result { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; - use datafusion::{common::test_util::batches_to_string, prelude::ParquetReadOptions}; + use datafusion::{ + common::test_util::batches_to_string, + execution::cache::{ + DefaultListFilesCache, cache_manager::CacheManagerConfig, + cache_unit::DefaultFileStatisticsCache, + }, + prelude::{ParquetReadOptions, col, lit, split_part}, + }; use insta::assert_snapshot; + use object_store::memory::InMemory; + use url::Url; fn assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); @@ -488,8 +515,7 @@ mod tests { ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); // input with single quote - let sql = - "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let sql = "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; @@ -502,8 +528,7 @@ mod tests { "#); // input with double quote - let sql = - "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let sql = "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; assert_snapshot!(batches_to_string(&rbs), @r#" @@ -523,8 +548,7 @@ mod tests { ctx.register_udtf("parquet_metadata", 
Arc::new(ParquetMetadataFunc {})); // input with string columns - let sql = - "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; + let sql = "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; @@ -592,9 +616,9 @@ mod tests { +-----------------------------------+-----------------+---------------------+------+------------------+ | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ - | alltypes_plain.parquet | 1851 | 6957 | 2 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 267014 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 996 | 2 | page_index=false | + | alltypes_plain.parquet | 1851 | 8882 | 2 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269074 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1339 | 2 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); @@ -623,12 +647,205 @@ mod tests { +-----------------------------------+-----------------+---------------------+------+------------------+ | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ - | alltypes_plain.parquet | 1851 | 6957 | 5 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 267014 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 996 | 3 | page_index=false | + | alltypes_plain.parquet | 1851 | 8882 | 5 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269074 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1339 | 3 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); Ok(()) } + + /// Shows that the statistics cache is not enabled by default yet + /// See https://github.com/apache/datafusion/issues/19217 + #[tokio::test] + async fn test_statistics_cache_default() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + for filename in [ + "alltypes_plain", + "alltypes_tiny_pages", + "lz4_raw_compressed_larger", + ] { + ctx.sql( + format!( + "create external table {filename} + stored as parquet + location '../parquet-testing/data/{filename}.parquet'", + ) + .as_str(), + ) + .await? 
+ .collect() + .await?; + } + + // When the cache manager creates a StatisticsCache by default, + // the contents will show up here + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, num_rows, num_columns, table_size_bytes from statistics_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + ++ + ++ + "); + + Ok(()) + } + + // Can be removed when https://github.com/apache/datafusion/issues/19217 is resolved + #[tokio::test] + async fn test_statistics_cache_override() -> Result<(), DataFusionError> { + // Install a specific StatisticsCache implementation + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build()?; + let config = SessionConfig::new().with_collect_statistics(true); + let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); + + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + for filename in [ + "alltypes_plain", + "alltypes_tiny_pages", + "lz4_raw_compressed_larger", + ] { + ctx.sql( + format!( + "create external table {filename} + stored as parquet + location '../parquet-testing/data/{filename}.parquet'", + ) + .as_str(), + ) + .await? + .collect() + .await?; + } + + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, num_rows, num_columns, table_size_bytes from statistics_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+--------------+-------------+------------------+ + | filename | file_size_bytes | num_rows | num_columns | table_size_bytes | + +-----------------------------------+-----------------+--------------+-------------+------------------+ + | alltypes_plain.parquet | 1851 | Exact(8) | 11 | Absent | + | alltypes_tiny_pages.parquet | 454233 | Exact(7300) | 13 | Absent | + | lz4_raw_compressed_larger.parquet | 380836 | Exact(10000) | 1 | Absent | + +-----------------------------------+-----------------+--------------+-------------+------------------+ + "); + + Ok(()) + } + + #[tokio::test] + async fn test_list_files_cache() -> Result<(), DataFusionError> { + let list_files_cache = Arc::new(DefaultListFilesCache::new( + 1024, + Some(Duration::from_secs(1)), + )); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default() + .with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + ctx.register_object_store( + &Url::parse("mem://test_table").unwrap(), + Arc::new(InMemory::new()), + ); + + ctx.register_udtf( + "list_files_cache", + Arc::new(ListFilesCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.sql( + "CREATE EXTERNAL TABLE src_table + STORED AS PARQUET + LOCATION '../parquet-testing/data/alltypes_plain.parquet'", + ) + .await? 
+ .collect() + .await?; + + ctx.sql("COPY (SELECT * FROM src_table) TO 'mem://test_table/0.parquet' STORED AS PARQUET").await?.collect().await?; + + ctx.sql("COPY (SELECT * FROM src_table) TO 'mem://test_table/1.parquet' STORED AS PARQUET").await?.collect().await?; + + ctx.sql( + "CREATE EXTERNAL TABLE test_table + STORED AS PARQUET + LOCATION 'mem://test_table/' + ", + ) + .await? + .collect() + .await?; + + let sql = "SELECT metadata_size_bytes, expires_in, metadata_list FROM list_files_cache()"; + let df = ctx + .sql(sql) + .await? + .unnest_columns(&["metadata_list"])? + .with_column_renamed("metadata_list", "metadata")? + .unnest_columns(&["metadata"])?; + + assert_eq!( + 2, + df.clone() + .filter(col("expires_in").is_not_null())? + .count() + .await? + ); + + let df = df + .with_column_renamed(r#""metadata.file_size_bytes""#, "file_size_bytes")? + .with_column_renamed(r#""metadata.e_tag""#, "etag")? + .with_column( + "filename", + split_part(col(r#""metadata.file_path""#), lit("/"), lit(-1)), + )? + .select_columns(&[ + "metadata_size_bytes", + "filename", + "file_size_bytes", + "etag", + ])? + .sort(vec![col("filename").sort(true, false)])?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + +---------------------+-----------+-----------------+------+ + | metadata_size_bytes | filename | file_size_bytes | etag | + +---------------------+-----------+-----------------+------+ + | 212 | 0.parquet | 3642 | 0 | + | 212 | 1.parquet | 3642 | 1 | + +---------------------+-----------+-----------------+------+ + "); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index e6e6be42c7ad0..34787838929f1 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -20,7 +20,7 @@ pub mod instrumented; use async_trait::async_trait; use aws_config::BehaviorVersion; use aws_credential_types::provider::{ - error::CredentialsError, ProvideCredentials, SharedCredentialsProvider, + ProvideCredentials, SharedCredentialsProvider, error::CredentialsError, }; use datafusion::{ common::{ @@ -33,12 +33,12 @@ use datafusion::{ }; use log::debug; use object_store::{ - aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, - gcp::GoogleCloudStorageBuilder, - http::HttpBuilder, ClientOptions, CredentialProvider, Error::Generic, ObjectStore, + aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, + gcp::GoogleCloudStorageBuilder, + http::HttpBuilder, }; use std::{ any::Any, @@ -64,6 +64,21 @@ pub async fn get_s3_object_store_builder( url: &Url, aws_options: &AwsOptions, resolve_region: bool, +) -> Result { + // Box the inner future to reduce the future size of this async function, + // which is deeply nested in the CLI's async call chain. + Box::pin(get_s3_object_store_builder_inner( + url, + aws_options, + resolve_region, + )) + .await +} + +async fn get_s3_object_store_builder_inner( + url: &Url, + aws_options: &AwsOptions, + resolve_region: bool, ) -> Result { let AwsOptions { access_key_id, @@ -124,14 +139,15 @@ pub async fn get_s3_object_store_builder( if let Some(endpoint) = endpoint { // Make a nicer error if the user hasn't allowed http and the endpoint // is http as the default message is "URL scheme is not allowed" - if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) { - if !matches!(allow_http, Some(true)) && endpoint_url.scheme() == "http" { - return config_err!( - "Invalid endpoint: {endpoint}. 
\ + if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) + && !matches!(allow_http, Some(true)) + && endpoint_url.scheme() == "http" + { + return config_err!( + "Invalid endpoint: {endpoint}. \ HTTP is not allowed for S3 endpoints. \ To allow HTTP, set 'aws.allow_http' to true" - ); - } + ); } builder = builder.with_endpoint(endpoint); @@ -208,7 +224,7 @@ impl CredentialsFromConfig { #[derive(Debug)] struct S3CredentialProvider { - credentials: aws_credential_types::provider::SharedCredentialsProvider, + credentials: SharedCredentialsProvider, } #[async_trait] @@ -586,8 +602,10 @@ mod tests { let location = "s3://bucket/path/FAKE/file.parquet"; // Set it to a non-existent file to avoid reading the default configuration file - std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); - std::env::set_var("AWS_SHARED_CREDENTIALS_FILE", "data/aws.credentials"); + unsafe { + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + std::env::set_var("AWS_SHARED_CREDENTIALS_FILE", "data/aws.credentials"); + } // No options let table_url = ListingTableUrl::parse(location)?; @@ -716,7 +734,10 @@ mod tests { .await .unwrap_err(); - assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true"); + assert_eq!( + err.to_string().lines().next().unwrap_or_default(), + "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true" + ); // Now add `allow_http` to the options and check if it works let sql = format!( @@ -743,10 +764,11 @@ mod tests { eprintln!("{e}"); return Ok(()); } - let expected_region = "eu-central-1"; let location = "s3://test-bucket/path/file.parquet"; // Set it to a non-existent file to avoid reading the default configuration file - std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + unsafe { + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + } let table_url = ListingTableUrl::parse(location)?; let aws_options = AwsOptions { @@ -758,17 +780,18 @@ mod tests { get_s3_object_store_builder(table_url.as_ref(), &aws_options, false).await?; // Verify that the region was auto-detected in test environment - assert_eq!( - builder.get_config_value(&AmazonS3ConfigKey::Region), - Some(expected_region.to_string()) + assert!( + builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_some() ); Ok(()) } #[tokio::test] - async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled( - ) -> Result<()> { + async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled() + -> Result<()> { if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { // Skip test if AWS envs are not set eprintln!("{e}"); @@ -806,7 +829,9 @@ mod tests { let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'" + ); let ctx = SessionContext::new(); ctx.register_table_options_extension_from_scheme(scheme); @@ -830,14 +855,15 @@ mod tests { 
#[tokio::test] async fn gcs_object_store_builder() -> Result<()> { let service_account_path = "fake_service_account_path"; - let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'" + ); let ctx = SessionContext::new(); ctx.register_table_options_extension_from_scheme(scheme); diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs index c4b63b417fe42..a0321cacb374b 100644 --- a/datafusion-cli/src/object_storage/instrumented.rs +++ b/datafusion-cli/src/object_storage/instrumented.rs @@ -20,8 +20,8 @@ use std::{ ops::AddAssign, str::FromStr, sync::{ - atomic::{AtomicU8, Ordering}, Arc, + atomic::{AtomicU8, AtomicU64, Ordering}, }, time::Duration, }; @@ -31,18 +31,67 @@ use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use chrono::Utc; use datafusion::{ - common::{instant::Instant, HashMap}, + common::{HashMap, instant::Instant}, error::DataFusionError, execution::object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, }; -use futures::stream::BoxStream; +use futures::stream::{BoxStream, Stream}; +use futures::{StreamExt, TryStreamExt}; use object_store::{ - path::Path, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, + CopyOptions, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, PutPayload, + PutResult, Result, path::Path, }; use parking_lot::{Mutex, RwLock}; use url::Url; +/// A stream wrapper that measures the time until the first response(item or end of stream) is yielded. +/// +/// The timer starts on the first `poll_next` call (not at stream creation) to avoid +/// measuring unrelated work between stream creation and first poll. +/// Duration is stored as nanoseconds in an `AtomicU64` (0 = not yet set). 
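Before the definition that follows, a hedged usage sketch of the wrapper; it mirrors the unit test added later in this patch (`timed_list` is illustrative and assumes module-level access to the private type):

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use futures::StreamExt;
use object_store::{ObjectStore, path::Path};

// Wrap a listing stream so the shared counter records time-to-first-item.
async fn timed_list(store: &dyn ObjectStore, prefix: &Path) {
    let duration_nanos = Arc::new(AtomicU64::new(0));
    let mut stream =
        TimeToFirstItemStream::new(store.list(Some(prefix)), Arc::clone(&duration_nanos));
    let _first = stream.next().await; // first poll starts and then stops the timer
    // 0 still means "no item observed yet"; otherwise elapsed nanoseconds.
    let _elapsed = duration_nanos.load(Ordering::Acquire);
}
```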
+struct TimeToFirstItemStream<S> {
+    inner: S,
+    start: Option<Instant>,
+    request_duration: Arc<AtomicU64>,
+    duration_recorded: bool,
+}
+
+impl<S> TimeToFirstItemStream<S> {
+    fn new(inner: S, request_duration: Arc<AtomicU64>) -> Self {
+        Self {
+            inner,
+            start: None,
+            request_duration,
+            duration_recorded: false,
+        }
+    }
+}
+
+impl<S, T> Stream for TimeToFirstItemStream<S>
+where
+    S: Stream<Item = Result<T>> + Unpin,
+{
+    type Item = Result<T>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let start = *self.start.get_or_insert_with(Instant::now);
+
+        let poll_result = std::pin::Pin::new(&mut self.inner).poll_next(cx);
+
+        if !self.duration_recorded && poll_result.is_ready() {
+            self.duration_recorded = true;
+            let nanos = start.elapsed().as_nanos() as u64;
+            self.request_duration.store(nanos, Ordering::Release);
+        }
+
+        poll_result
+    }
+}
+
 /// The profiling mode to use for an [`InstrumentedObjectStore`] instance. Collecting profiling
 /// data will have a small negative impact on both CPU and memory usage. Default is `Disabled`
 #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
@@ -57,7 +106,7 @@ pub enum InstrumentedObjectStoreMode {
 }

 impl fmt::Display for InstrumentedObjectStoreMode {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{self:?}")
     }
 }
@@ -91,7 +140,7 @@ impl From<u8> for InstrumentedObjectStoreMode {
 pub struct InstrumentedObjectStore {
     inner: Arc<dyn ObjectStore>,
     instrument_mode: AtomicU8,
-    requests: Mutex<Vec<RequestDetails>>,
+    requests: Arc<Mutex<Vec<RequestDetails>>>,
 }

 impl InstrumentedObjectStore {
@@ -100,7 +149,7 @@
         Self {
             inner: object_store,
             instrument_mode,
-            requests: Mutex::new(Vec::new()),
+            requests: Arc::new(Mutex::new(Vec::new())),
         }
     }

@@ -137,7 +186,7 @@
             op: Operation::Put,
             path: location.clone(),
             timestamp,
-            duration: Some(elapsed),
+            duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)),
             size: Some(size),
             range: None,
             extra_display: None,
@@ -160,7 +209,7 @@
             op: Operation::Put,
             path: location.clone(),
             timestamp,
-            duration: Some(elapsed),
+            duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)),
             size: None,
             range: None,
             extra_display: None,
@@ -177,16 +226,26 @@
         let timestamp = Utc::now();
         let range = options.range.clone();
+        let head = options.head;
         let start = Instant::now();
         let ret = self.inner.get_opts(location, options).await?;
         let elapsed = start.elapsed();

+        let (op, size) = if head {
+            (Operation::Head, None)
+        } else {
+            (
+                Operation::Get,
+                Some((ret.range.end - ret.range.start) as usize),
+            )
+        };
+
         self.requests.lock().push(RequestDetails {
-            op: Operation::Get,
+            op,
             path: location.clone(),
             timestamp,
-            duration: Some(elapsed),
-            size: Some((ret.range.end - ret.range.start) as usize),
+            duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)),
+            size,
             range,
             extra_display: None,
         });
@@ -194,23 +253,30 @@
         Ok(ret)
     }

-    async fn instrumented_delete(&self, location: &Path) -> Result<()> {
+    fn instrumented_delete_stream(
+        &self,
+        locations: BoxStream<'static, Result<Path>>,
+    ) -> BoxStream<'static, Result<Path>> {
+        let requests_captured = Arc::clone(&self.requests);
+        let timestamp = Utc::now();
         let start = Instant::now();
-        self.inner.delete(location).await?;
-        let elapsed = start.elapsed();
-
-        self.requests.lock().push(RequestDetails {
-            op: Operation::Delete,
-            path: location.clone(),
-            timestamp,
-            duration:
Some(elapsed), - size: None, - range: None, - extra_display: None, - }); - - Ok(()) + self.inner + .delete_stream(locations) + .and_then(move |location| { + let elapsed = start.elapsed(); + requests_captured.lock().push(RequestDetails { + op: Operation::Delete, + path: location.clone(), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: None, + range: None, + extra_display: None, + }); + futures::future::ok(location) + }) + .boxed() } fn instrumented_list( @@ -218,19 +284,20 @@ impl InstrumentedObjectStore { prefix: Option<&Path>, ) -> BoxStream<'static, Result> { let timestamp = Utc::now(); - let ret = self.inner.list(prefix); + let inner_stream = self.inner.list(prefix); + let duration_nanos = Arc::new(AtomicU64::new(0)); self.requests.lock().push(RequestDetails { op: Operation::List, path: prefix.cloned().unwrap_or_else(|| Path::from("")), timestamp, - duration: None, // list returns a stream, so the duration isn't meaningful + duration_nanos: Arc::clone(&duration_nanos), size: None, range: None, extra_display: None, }); - ret + Box::pin(TimeToFirstItemStream::new(inner_stream, duration_nanos)) } async fn instrumented_list_with_delimiter( @@ -246,7 +313,7 @@ impl InstrumentedObjectStore { op: Operation::List, path: prefix.cloned().unwrap_or_else(|| Path::from("")), timestamp, - duration: Some(elapsed), + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), size: None, range: None, extra_display: None, @@ -265,7 +332,7 @@ impl InstrumentedObjectStore { op: Operation::Copy, path: from.clone(), timestamp, - duration: Some(elapsed), + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), size: None, range: None, extra_display: Some(format!("copy_to: {to}")), @@ -288,7 +355,7 @@ impl InstrumentedObjectStore { op: Operation::Copy, path: from.clone(), timestamp, - duration: Some(elapsed), + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), size: None, range: None, extra_display: Some(format!("copy_to: {to}")), @@ -296,29 +363,10 @@ impl InstrumentedObjectStore { Ok(()) } - - async fn instrumented_head(&self, location: &Path) -> Result { - let timestamp = Utc::now(); - let start = Instant::now(); - let ret = self.inner.head(location).await?; - let elapsed = start.elapsed(); - - self.requests.lock().push(RequestDetails { - op: Operation::Head, - path: location.clone(), - timestamp, - duration: Some(elapsed), - size: None, - range: None, - extra_display: None, - }); - - Ok(ret) - } } impl fmt::Display for InstrumentedObjectStore { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mode: InstrumentedObjectStoreMode = self.instrument_mode.load(Ordering::Relaxed).into(); write!( @@ -364,12 +412,15 @@ impl ObjectStore for InstrumentedObjectStore { self.inner.get_opts(location, options).await } - async fn delete(&self, location: &Path) -> Result<()> { + fn delete_stream( + &self, + locations: BoxStream<'static, Result>, + ) -> BoxStream<'static, Result> { if self.enabled() { - return self.instrumented_delete(location).await; + return self.instrumented_delete_stream(locations); } - self.inner.delete(location).await + self.inner.delete_stream(locations) } fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { @@ -388,28 +439,24 @@ impl ObjectStore for InstrumentedObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - if self.enabled() { - return 
self.instrumented_copy(from, to).await; - } - - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - if self.enabled() { - return self.instrumented_copy_if_not_exists(from, to).await; - } - - self.inner.copy_if_not_exists(from, to).await - } - - async fn head(&self, location: &Path) -> Result { + async fn copy_opts( + &self, + from: &Path, + to: &Path, + options: CopyOptions, + ) -> Result<()> { if self.enabled() { - return self.instrumented_head(location).await; + return match options.mode { + object_store::CopyMode::Create => { + self.instrumented_copy_if_not_exists(from, to).await + } + object_store::CopyMode::Overwrite => { + self.instrumented_copy(from, to).await + } + }; } - self.inner.head(location).await + self.inner.copy_opts(from, to, options).await } } @@ -425,32 +472,57 @@ pub enum Operation { } impl fmt::Display for Operation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:?}") } } /// Holds profiling details about individual requests made through an [`InstrumentedObjectStore`] -#[derive(Debug)] pub struct RequestDetails { op: Operation, path: Path, timestamp: chrono::DateTime, - duration: Option, + /// Duration stored as nanoseconds in an AtomicU64. 0 means not yet set. + duration_nanos: Arc, size: Option, range: Option, extra_display: Option, } +impl fmt::Debug for RequestDetails { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequestDetails") + .field("op", &self.op) + .field("path", &self.path) + .field("timestamp", &self.timestamp) + .field("duration", &self.duration()) + .field("size", &self.size) + .field("range", &self.range) + .field("extra_display", &self.extra_display) + .finish() + } +} + +impl RequestDetails { + fn duration(&self) -> Option { + let nanos = self.duration_nanos.load(Ordering::Acquire); + if nanos == 0 { + None + } else { + Some(Duration::from_nanos(nanos)) + } + } +} + impl fmt::Display for RequestDetails { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut output_parts = vec![format!( "{} operation={:?}", self.timestamp.to_rfc3339(), self.op )]; - if let Some(d) = self.duration { + if let Some(d) = self.duration() { output_parts.push(format!("duration={:.6}s", d.as_secs_f32())); } if let Some(s) = self.size { @@ -637,7 +709,7 @@ impl RequestSummary { } fn push(&mut self, request: &RequestDetails) { self.count += 1; - if let Some(dur) = request.duration { + if let Some(dur) = request.duration() { self.duration_stats.get_or_insert_default().push(dur) } if let Some(size) = request.size { @@ -758,6 +830,7 @@ impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry { #[cfg(test)] mod tests { + use futures::StreamExt; use object_store::WriteMultipart; use super::*; @@ -782,9 +855,11 @@ mod tests { "TRaCe".parse().unwrap(), InstrumentedObjectStoreMode::Trace )); - assert!("does_not_exist" - .parse::() - .is_err()); + assert!( + "does_not_exist" + .parse::() + .is_err() + ); assert!(matches!(0.into(), InstrumentedObjectStoreMode::Disabled)); assert!(matches!(1.into(), InstrumentedObjectStoreMode::Summary)); @@ -850,7 +925,7 @@ mod tests { let request = requests.pop().unwrap(); assert_eq!(request.op, Operation::Get); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert_eq!(request.size, Some(9)); 
assert_eq!(request.range, None); assert!(request.extra_display.is_none()); @@ -879,7 +954,7 @@ mod tests { let request = requests.pop().unwrap(); assert_eq!(request.op, Operation::Delete); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); @@ -896,18 +971,58 @@ mod tests { instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); assert!(instrumented.requests.lock().is_empty()); - let _ = instrumented.list(Some(&path)); + let mut stream = instrumented.list(Some(&path)); + // Sleep between stream creation and first poll to verify the timer + // starts on first poll, not at stream creation. + let delay = Duration::from_millis(50); + tokio::time::sleep(delay).await; + let _ = stream.next().await; assert_eq!(instrumented.requests.lock().len(), 1); let request = instrumented.take_requests().pop().unwrap(); assert_eq!(request.op, Operation::List); assert_eq!(request.path, path); - assert!(request.duration.is_none()); + let duration = request + .duration() + .expect("duration should be set after consuming stream"); + assert!( + duration < delay, + "duration {duration:?} should exclude the {delay:?} sleep before first poll" + ); assert!(request.size.is_none()); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); } + #[tokio::test] + async fn time_to_first_item_stream_captures_inner_latency() { + let inner_delay = Duration::from_millis(50); + let inner_stream = futures::stream::once(async move { + tokio::time::sleep(inner_delay).await; + Ok(ObjectMeta { + location: Path::from("test"), + last_modified: Utc::now(), + size: 0, + e_tag: None, + version: None, + }) + }) + .boxed(); + + let duration_nanos = Arc::new(AtomicU64::new(0)); + let mut stream = Box::pin(TimeToFirstItemStream::new( + inner_stream, + Arc::clone(&duration_nanos), + )); + let _ = stream.next().await; + + let recorded = Duration::from_nanos(duration_nanos.load(Ordering::Acquire)); + assert!( + recorded >= inner_delay, + "recorded duration {recorded:?} should be >= inner stream delay {inner_delay:?}" + ); + } + #[tokio::test] async fn instrumented_store_list_with_delimiter() { let (instrumented, path) = setup_test_store().await; @@ -925,7 +1040,7 @@ mod tests { let request = instrumented.take_requests().pop().unwrap(); assert_eq!(request.op, Operation::List); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); @@ -956,7 +1071,7 @@ mod tests { let request = instrumented.take_requests().pop().unwrap(); assert_eq!(request.op, Operation::Put); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert_eq!(request.size.unwrap(), size); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); @@ -991,7 +1106,7 @@ mod tests { let request = instrumented.take_requests().pop().unwrap(); assert_eq!(request.op, Operation::Put); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); @@ -1019,7 +1134,7 @@ mod tests { let request = requests.pop().unwrap(); assert_eq!(request.op, Operation::Copy); assert_eq!(request.path, path); - 
assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert_eq!( @@ -1058,7 +1173,7 @@ mod tests { let request = requests.pop().unwrap(); assert_eq!(request.op, Operation::Copy); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert_eq!( @@ -1088,7 +1203,7 @@ mod tests { let request = requests.pop().unwrap(); assert_eq!(request.op, Operation::Head); assert_eq!(request.path, path); - assert!(request.duration.is_some()); + assert!(request.duration().is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); @@ -1100,7 +1215,9 @@ mod tests { op: Operation::Get, path: Path::from("test"), timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), - duration: Some(Duration::new(5, 0)), + duration_nanos: Arc::new(AtomicU64::new( + Duration::new(5, 0).as_nanos() as u64 + )), size: Some(10), range: Some((..10).into()), extra_display: Some(String::from("extra info")), @@ -1127,7 +1244,9 @@ mod tests { op: Operation::Get, path: Path::from("test1"), timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), - duration: Some(Duration::from_secs(5)), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(5).as_nanos() as u64 + )), size: Some(100), range: None, extra_display: None, @@ -1147,7 +1266,9 @@ mod tests { op: Operation::Get, path: Path::from("test2"), timestamp: chrono::DateTime::from_timestamp(1, 0).unwrap(), - duration: Some(Duration::from_secs(8)), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(8).as_nanos() as u64 + )), size: Some(150), range: None, extra_display: None, @@ -1156,7 +1277,9 @@ mod tests { op: Operation::Get, path: Path::from("test3"), timestamp: chrono::DateTime::from_timestamp(2, 0).unwrap(), - duration: Some(Duration::from_secs(2)), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(2).as_nanos() as u64 + )), size: Some(50), range: None, extra_display: None, @@ -1175,7 +1298,9 @@ mod tests { op: Operation::Put, path: Path::from("test4"), timestamp: chrono::DateTime::from_timestamp(3, 0).unwrap(), - duration: Some(Duration::from_millis(200)), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_millis(200).as_nanos() as u64, + )), size: Some(75), range: None, extra_display: None, @@ -1200,7 +1325,9 @@ mod tests { op: Operation::Get, path: Path::from("test1"), timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), - duration: Some(Duration::from_secs(3)), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(3).as_nanos() as u64 + )), size: None, range: None, extra_display: None, @@ -1222,7 +1349,7 @@ mod tests { op: Operation::Get, path: Path::from("test1"), timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), - duration: None, + duration_nanos: Arc::new(AtomicU64::new(0)), size: Some(200), range: None, extra_display: None, @@ -1244,7 +1371,7 @@ mod tests { op: Operation::Get, path: Path::from("test1"), timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), - duration: None, + duration_nanos: Arc::new(AtomicU64::new(0)), size: None, range: None, extra_display: None, diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 56bdb15a315d9..0443a7a289602 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -97,7 +97,7 @@ fn keep_only_maxrows(s: 
&str, maxrows: usize) -> String { let last_line = &lines[lines.len() - 1]; // bottom border line let spaces = last_line.len().saturating_sub(4); - let dotted_line = format!("| .{: SchemaRef { + let fields: Vec = (0..10) + .map(|i| Field::new(format!("c{i}"), DataType::Int32, false)) + .collect(); + Arc::new(Schema::new(fields)) + } + + /// return a batch with many columns and three rows + fn wide_column_batch() -> RecordBatch { + let arrays: Vec> = (0..10) + .map(|_| Arc::new(Int32Array::from(vec![0, 1, 2])) as _) + .collect(); + RecordBatch::try_new(wide_column_schema(), arrays).unwrap() + } + /// Slice the record batch into 2 batches - fn split_batch(batch: RecordBatch) -> Vec { + fn split_batch(batch: &RecordBatch) -> Vec { assert!(batch.num_rows() > 1); let split = batch.num_rows() / 2; vec![ diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 93d1d450fd82b..d0810cb034df1 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -28,8 +28,8 @@ use crate::print_format::PrintFormat; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion::common::instant::Instant; use datafusion::common::DataFusionError; +use datafusion::common::instant::Instant; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; @@ -55,8 +55,10 @@ impl FromStr for MaxRows { Ok(Self::Unlimited) } else { match maxrows.parse::() { - Ok(nrows) => Ok(Self::Limited(nrows)), - _ => Err(format!("Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.")), + Ok(nrows) => Ok(Self::Limited(nrows)), + _ => Err(format!( + "Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit." 
+ )), } } } @@ -113,7 +115,7 @@ impl PrintOptions { row_count: usize, format_options: &FormatOptions, ) -> Result<()> { - let stdout = std::io::stdout(); + let stdout = io::stdout(); let mut writer = stdout.lock(); self.format.print_batches( @@ -135,7 +137,7 @@ impl PrintOptions { query_start_time, ); - self.write_output(&mut writer, formatted_exec_details) + self.write_output(&mut writer, &formatted_exec_details) } /// Print the stream to stdout using the specified format @@ -151,7 +153,7 @@ impl PrintOptions { )); }; - let stdout = std::io::stdout(); + let stdout = io::stdout(); let mut writer = stdout.lock(); let mut row_count = 0_usize; @@ -177,13 +179,13 @@ impl PrintOptions { query_start_time, ); - self.write_output(&mut writer, formatted_exec_details) + self.write_output(&mut writer, &formatted_exec_details) } fn write_output( &self, writer: &mut W, - formatted_exec_details: String, + formatted_exec_details: &str, ) -> Result<()> { if !self.quiet { writeln!(writer, "{formatted_exec_details}")?; @@ -235,11 +237,11 @@ mod tests { let mut print_output: Vec = Vec::new(); let exec_out = String::from("Formatted Exec Output"); - print_options.write_output(&mut print_output, exec_out.clone())?; + print_options.write_output(&mut print_output, &exec_out)?; assert!(print_output.is_empty()); print_options.quiet = false; - print_options.write_output(&mut print_output, exec_out.clone())?; + print_options.write_output(&mut print_output, &exec_out)?; let out_str: String = print_output .clone() .try_into() @@ -251,7 +253,7 @@ mod tests { print_options .instrumented_registry .set_instrument_mode(InstrumentedObjectStoreMode::Trace); - print_options.write_output(&mut print_output, exec_out.clone())?; + print_options.write_output(&mut print_output, &exec_out)?; let out_str: String = print_output .clone() .try_into() diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index c1395aa4f562c..be4a2ad4fe197 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -20,14 +20,17 @@ use std::process::Command; use rstest::rstest; use async_trait::async_trait; -use insta::{glob, Settings}; +use insta::internals::SettingsBindDropGuard; +use insta::{Settings, glob}; use insta_cmd::{assert_cmd_snapshot, get_cargo_bin}; use std::path::PathBuf; use std::{env, fs}; -use testcontainers::core::{CmdWaitFor, ExecCommand, Mount}; -use testcontainers::runners::AsyncRunner; -use testcontainers::{ContainerAsync, ImageExt, TestcontainersError}; use testcontainers_modules::minio; +use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand, Mount}; +use testcontainers_modules::testcontainers::runners::AsyncRunner; +use testcontainers_modules::testcontainers::{ + ContainerAsync, ImageExt, TestcontainersError, +}; fn cli() -> Command { Command::new(get_cargo_bin("datafusion-cli")) @@ -42,7 +45,7 @@ fn make_settings() -> Settings { settings } -async fn setup_minio_container() -> ContainerAsync { +async fn setup_minio_container() -> Result, String> { const MINIO_ROOT_USER: &str = "TEST-DataFusionLogin"; const MINIO_ROOT_PASSWORD: &str = "TEST-DataFusionPassword"; @@ -97,25 +100,23 @@ async fn setup_minio_container() -> ContainerAsync { let stdout = container.stdout_to_vec().await.unwrap_or_default(); let stderr = container.stderr_to_vec().await.unwrap_or_default(); - panic!( + return Err(format!( "Failed to execute command: {}\nError: {}\nStdout: {:?}\nStderr: {:?}", cmd_ref, e, String::from_utf8_lossy(&stdout), 
String::from_utf8_lossy(&stderr) - ); + )); } } - container + Ok(container) } - Err(TestcontainersError::Client(e)) => { - panic!("Failed to start MinIO container. Ensure Docker is running and accessible: {e}"); - } - Err(e) => { - panic!("Failed to start MinIO container: {e}"); - } + Err(TestcontainersError::Client(e)) => Err(format!( + "Failed to start MinIO container. Ensure Docker is running and accessible: {e}" + )), + Err(e) => Err(format!("Failed to start MinIO container: {e}")), } } @@ -215,6 +216,42 @@ fn test_cli_top_memory_consumers<'a>( #[case] snapshot_name: &str, #[case] top_memory_consumers: impl IntoIterator<Item = &'a str>, ) { + let _bound = bind_to_settings(snapshot_name); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args(["--memory-limit", "10M", "--command", sql]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + +#[rstest] +#[case("no_track", ["--top-memory-consumers", "0"])] +#[case("top2", ["--top-memory-consumers", "2"])] +#[test] +fn test_cli_top_memory_consumers_with_mem_pool_type<'a>( + #[case] snapshot_name: &str, + #[case] top_memory_consumers: impl IntoIterator<Item = &'a str>, +) { + let _bound = bind_to_settings(snapshot_name); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args([ + "--memory-limit", + "10M", + "--mem-pool-type", + "fair", + "--command", + sql, + ]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + +fn bind_to_settings(snapshot_name: &str) -> SettingsBindDropGuard { let mut settings = make_settings(); settings.set_snapshot_suffix(snapshot_name); @@ -224,20 +261,45 @@ fn test_cli_top_memory_consumers<'a>( "Consumer(can spill: bool) consumed XB, peak XB", ); settings.add_filter( - r"Error: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + r"Error: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total memory pool: '.*?'", "Error: Failed to allocate ", ); settings.add_filter( - r"Resources exhausted: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + r"Resources exhausted: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? 
remain available for the total memory pool: '.*?'", "Resources exhausted: Failed to allocate", ); + settings.bind_to_scope() +} + +#[test] +fn test_cli_with_unbounded_memory_pool() { + let mut settings = make_settings(); + + settings.set_snapshot_suffix("default"); + let _bound = settings.bind_to_scope(); let mut cmd = cli(); let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; - cmd.args(["--memory-limit", "10M", "--command", sql]); - cmd.args(top_memory_consumers); + cmd.args(["--maxrows", "10", "--command", sql]); + + assert_cmd_snapshot!(cmd); +} + +#[test] +fn test_cli_wide_result_set_no_crash() { + let mut settings = make_settings(); + + settings.set_snapshot_suffix("wide_result_set"); + + let _bound = settings.bind_to_scope(); + + let mut cmd = cli(); + let sql = "SELECT v1 as c0, v1+1 as c1, v1+2 as c2, v1+3 as c3, v1+4 as c4, \ + v1+5 as c5, v1+6 as c6, v1+7 as c7, v1+8 as c8, v1+9 as c9 \ + FROM generate_series(1, 100) as t1(v1);"; + cmd.args(["--maxrows", "5", "--command", sql]); assert_cmd_snapshot!(cmd); } @@ -249,7 +311,14 @@ async fn test_cli() { return; } - let container = setup_minio_container().await; + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let settings = make_settings(); let _bound = settings.bind_to_scope(); @@ -258,13 +327,15 @@ async fn test_cli() { glob!("sql/integration/*.sql", |path| { let input = fs::read_to_string(path).unwrap(); - assert_cmd_snapshot!(cli() - .env_clear() - .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") - .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") - .env("AWS_ENDPOINT", format!("http://localhost:{port}")) - .env("AWS_ALLOW_HTTP", "true") - .pass_stdin(input)) + assert_cmd_snapshot!( + cli() + .env_clear() + .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") + .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") + .env("AWS_ENDPOINT", format!("http://localhost:{port}")) + .env("AWS_ALLOW_HTTP", "true") + .pass_stdin(input) + ) }); } @@ -280,7 +351,14 @@ async fn test_aws_options() { let settings = make_settings(); let _bound = settings.bind_to_scope(); - let container = setup_minio_container().await; + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let port = container.get_host_port_ipv4(9000).await.unwrap(); let input = format!( @@ -328,10 +406,12 @@ SELECT COUNT(*) FROM hits; "# ); - assert_cmd_snapshot!(cli() - .env("RUST_LOG", "warn") - .env_remove("AWS_ENDPOINT") - .pass_stdin(input)); + assert_cmd_snapshot!( + cli() + .env("RUST_LOG", "warn") + .env_remove("AWS_ENDPOINT") + .pass_stdin(input) + ); } /// Ensure backtrace will be printed, if executing `datafusion-cli` with a query @@ -351,14 +431,12 @@ fn test_backtrace_output(#[case] query: &str) { let output = cmd.output().expect("Failed to execute command"); let stdout = String::from_utf8_lossy(&output.stdout); let stderr = String::from_utf8_lossy(&output.stderr); - let combined_output = format!("{}{}", stdout, stderr); + let combined_output = format!("{stdout}{stderr}"); // Assert that the output includes literal 'backtrace' assert!( combined_output.to_lowercase().contains("backtrace"), - "Expected output to contain 'backtrace', but got stdout: '{}' stderr: '{}'", - stdout, - 
stderr + "Expected output to contain 'backtrace', but got stdout: '{stdout}' stderr: '{stderr}'" ); } @@ -369,7 +447,14 @@ async fn test_s3_url_fallback() { return; } - let container = setup_minio_container().await; + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let mut settings = make_settings(); settings.set_snapshot_suffix("s3_url_fallback"); @@ -399,8 +484,14 @@ async fn test_object_store_profiling() { return; } - let container = setup_minio_container().await; - + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; let mut settings = make_settings(); // as the object store profiling contains timestamps and durations, we must @@ -450,7 +541,7 @@ SELECT * from CARS LIMIT 1; #[async_trait] trait MinioCommandExt { async fn with_minio(&mut self, container: &ContainerAsync<minio::MinIO>) - -> &mut Self; + -> &mut Self; } #[async_trait] diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 6b3a247dd7b82..5f43ca88dc9d7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -7,7 +7,6 @@ info: - EXPLAIN SELECT 123 env: DATAFUSION_EXPLAIN_FORMAT: pgjson -snapshot_kind: text --- success: true exit_code: 0 @@ -19,19 +18,19 @@ exit_code: 0 | logical_plan | [ | | | { | | | "Plan": { | -| | "Expressions": [ | -| | "Int64(123)" | -| | ], | | | "Node Type": "Projection", | -| | "Output": [ | +| | "Expressions": [ | | | "Int64(123)" | | | ], | | | "Plans": [ | | | { | | | "Node Type": "EmptyRelation", | -| | "Output": [], | -| | "Plans": [] | +| | "Plans": [], | +| | "Output": [] | | | } | +| | ], | +| | "Output": [ | +| | "Int64(123)" | | | ] | | | } | | | } | diff --git a/datafusion-cli/tests/snapshots/cli_format@automatic.snap b/datafusion-cli/tests/snapshots/cli_format@automatic.snap index 2591f493e90a8..76b14d9a3a924 100644 --- a/datafusion-cli/tests/snapshots/cli_format@automatic.snap +++ b/datafusion-cli/tests/snapshots/cli_format@automatic.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@csv.snap b/datafusion-cli/tests/snapshots/cli_format@csv.snap index c41b042298eb0..2c969bd91d121 100644 --- a/datafusion-cli/tests/snapshots/cli_format@csv.snap +++ b/datafusion-cli/tests/snapshots/cli_format@csv.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@json.snap b/datafusion-cli/tests/snapshots/cli_format@json.snap index 8f804a337cce5..22a9cc4657a91 100644 --- a/datafusion-cli/tests/snapshots/cli_format@json.snap +++ b/datafusion-cli/tests/snapshots/cli_format@json.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git
a/datafusion-cli/tests/snapshots/cli_format@nd-json.snap b/datafusion-cli/tests/snapshots/cli_format@nd-json.snap index 7b4ce1e2530cf..513bcb7372ca6 100644 --- a/datafusion-cli/tests/snapshots/cli_format@nd-json.snap +++ b/datafusion-cli/tests/snapshots/cli_format@nd-json.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@table.snap b/datafusion-cli/tests/snapshots/cli_format@table.snap index 99914182462aa..8677847588385 100644 --- a/datafusion-cli/tests/snapshots/cli_format@table.snap +++ b/datafusion-cli/tests/snapshots/cli_format@table.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@tsv.snap b/datafusion-cli/tests/snapshots/cli_format@tsv.snap index 968268c31dd55..c56e60fcab155 100644 --- a/datafusion-cli/tests/snapshots/cli_format@tsv.snap +++ b/datafusion-cli/tests/snapshots/cli_format@tsv.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap b/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap index c27d527df0b6a..9fd07fa6f4e1b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap b/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap index 46ee6be64f624..8620f6da84488 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap @@ -5,7 +5,6 @@ info: args: - "--command" - EXPLAIN SELECT 123 -snapshot_kind: text --- success: true exit_code: 0 diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@files.snap b/datafusion-cli/tests/snapshots/cli_quick_test@files.snap index 7c44e41729a17..df3a10b6bb54b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@files.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@files.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap b/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap index 3b975bb6a927d..a394458768d1b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap index 89b646a531f8b..c34e1202f55da 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -14,8 +14,8 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. 
Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by -Resources exhausted: Failed to allocate +Resources exhausted: Failed to allocate additional 128.0 KB for ExternalSorter[0] with 0.0 B already allocated for this reservation - 0.0 B remain available for the total memory pool: greedy(used: 10.0 MB, pool_size: 10.0 MB) ----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap index 62f864b3adb6e..ebf7a540d8d44 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -14,11 +14,11 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: Consumer(can spill: bool) consumed XB, peak XB, Consumer(can spill: bool) consumed XB, peak XB. -Error: Failed to allocate +Error: Failed to allocate additional 128.0 KB for ExternalSorter[0] with 0.0 B already allocated for this reservation - 0.0 B remain available for the total memory pool: greedy(used: 10.0 MB, pool_size: 10.0 MB) ----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap index 9845d095c9180..9e279ca93ddcd 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -12,12 +12,12 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: Consumer(can spill: bool) consumed XB, peak XB, Consumer(can spill: bool) consumed XB, peak XB, Consumer(can spill: bool) consumed XB, peak XB. 
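The pool description in the rewritten error text ("greedy(used: 10.0 MB, pool_size: 10.0 MB)" here, "fair(pool_size: 10.0 MB)" in the new `with_mem_pool_type` snapshots further down) comes from the memory pool the CLI was started with. A minimal sketch, not taken from this PR, of configuring the equivalent pools programmatically, assuming the current `RuntimeEnvBuilder` and memory pool APIs:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Equivalent of `--memory-limit 10M --mem-pool-type fair`; swap in
    // `GreedyMemoryPool::new(10 * 1024 * 1024)` for the greedy variant.
    let pool = Arc::new(FairSpillPool::new(10 * 1024 * 1024));
    let _greedy = GreedyMemoryPool::new(10 * 1024 * 1024); // alternative pool
    let runtime = RuntimeEnvBuilder::new().with_memory_pool(pool).build_arc()?;
    let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime);

    // A large sort against a 10 MB pool should fail with a
    // "Resources exhausted ... fair(pool_size: 10.0 MB)" error like the
    // ones captured in these snapshots.
    let df = ctx
        .sql("select * from generate_series(1,500000) as t1(v1) order by v1")
        .await?;
    let _ = df.collect().await; // expected to return an error, not panic
    Ok(())
}
```

The greedy pool hands out memory first come, first served, while the fair pool reserves an equal share for each spillable consumer, which is why the two snapshot families report different pool descriptions.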
-Error: Failed to allocate +Error: Failed to allocate additional 128.0 KB for ExternalSorter[0] with 0.0 B already allocated for this reservation - 0.0 B remain available for the total memory pool: greedy(used: 10.0 MB, pool_size: 10.0 MB) ----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@no_track.snap new file mode 100644 index 0000000000000..9a228fcfb6e93 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@no_track.snap @@ -0,0 +1,23 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--mem-pool-type" + - fair + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "0" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. +caused by +Resources exhausted: Failed to allocate additional 128.0 KB for ExternalSorter[0] with 0.0 B already allocated for this reservation - 0.0 B remain available for the total memory pool: fair(pool_size: 10.0 MB) + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@top2.snap new file mode 100644 index 0000000000000..d7f964a339313 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@top2.snap @@ -0,0 +1,26 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--mem-pool-type" + - fair + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "2" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. +caused by +Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB. 
+Error: Failed to allocate additional 128.0 KB for ExternalSorter[0] with 0.0 B already allocated for this reservation - 0.0 B remain available for the total memory pool: fair(pool_size: 10.0 MB) + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_wide_result_set_no_crash@wide_result_set.snap b/datafusion-cli/tests/snapshots/cli_wide_result_set_no_crash@wide_result_set.snap new file mode 100644 index 0000000000000..30b34f3c12baa --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_wide_result_set_no_crash@wide_result_set.snap @@ -0,0 +1,32 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +assertion_line: 307 +info: + program: datafusion-cli + args: + - "--maxrows" + - "5" + - "--command" + - "SELECT v1 as c0, v1+1 as c1, v1+2 as c2, v1+3 as c3, v1+4 as c4, v1+5 as c5, v1+6 as c6, v1+7 as c7, v1+8 as c8, v1+9 as c9 FROM generate_series(1, 100) as t1(v1);" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] ++----+----+----+----+----+----+----+----+----+----+ +| c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | ++----+----+----+----+----+----+----+----+----+----+ +| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | +| 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | +| 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | +| 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | +| 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | +| . | +| . | +| . | ++----+----+----+----+----+----+----+----+----+----+ +100 row(s) fetched. (First 5 displayed. Use --maxrows to adjust) +[ELAPSED] + + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_with_unbounded_memory_pool@default.snap b/datafusion-cli/tests/snapshots/cli_with_unbounded_memory_pool@default.snap new file mode 100644 index 0000000000000..7bdcd63dc7be6 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_with_unbounded_memory_pool@default.snap @@ -0,0 +1,36 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--maxrows" + - "10" + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] ++----+ +| v1 | ++----+ +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | +| 10 | +| . | +| . | +| . | ++----+ +500000 row(s) fetched. (First 10 displayed. Use --maxrows to adjust) +[ELAPSED] + + +----- stderr ----- diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index bb0525e57753b..e56f5ad6b8ca7 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -29,63 +29,50 @@ license = { workspace = true } authors = { workspace = true } rust-version = { workspace = true } +# Note: add additional linter rules in lib.rs. 
+# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true -[[example]] -name = "flight_sql_server" -path = "examples/flight/flight_sql_server.rs" - -[[example]] -name = "flight_server" -path = "examples/flight/flight_server.rs" - -[[example]] -name = "flight_client" -path = "examples/flight/flight_client.rs" - -[[example]] -name = "dataframe_to_s3" -path = "examples/external_dependency/dataframe-to-s3.rs" - -[[example]] -name = "query_aws_s3" -path = "examples/external_dependency/query-aws-s3.rs" - -[[example]] -name = "custom_file_casts" -path = "examples/custom_file_casts.rs" +[dependencies] +arrow = { workspace = true } +arrow-schema = { workspace = true } +datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } +datafusion-common = { workspace = true } +nom = "8.0.0" +tempfile = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } [dev-dependencies] -arrow = { workspace = true } -# arrow_schema is required for record_batch! macro :sad: arrow-flight = { workspace = true } -arrow-schema = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples base64 = "0.22.1" -datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } -datafusion-ffi = { workspace = true } +datafusion-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } datafusion-proto = { workspace = true } +datafusion-sql = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } +insta = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } +serde = { version = "1", features = ["derive"] } serde_json = { workspace = true } -tempfile = { workspace = true } +strum = { workspace = true } +strum_macros = { workspace = true } test-utils = { path = "../test-utils" } -tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } tonic = "0.14" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.18" +uuid = { workspace = true } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index f1bcbcce82004..073f269d4a35d 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -39,59 +39,193 @@ git submodule update --init # Change to the examples directory cd datafusion-examples/examples -# Run the `dataframe` example: -# ... 
use the equivalent for other examples -cargo run --example dataframe +# Run all examples in a group +cargo run --example <group_name> -- all + +# Run a specific example within a group +cargo run --example <group_name> -- <example_name> + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe ``` -## Single Process - -- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) -- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) -- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) -- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files -- [`async_udf.rs`](examples/async_udf.rs): Define and invoke an asynchronous User Defined Scalar Function (UDF) -- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) -- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog -- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization -- [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file -- [`csv_json_opener.rs`](examples/csv_json_opener.rs): Use low level `FileOpener` APIs to read CSV/JSON into Arrow `RecordBatch`es -- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) -- [`custom_file_casts.rs`](examples/custom_file_casts.rs): Implement custom casting rules to adapt file schemas -- [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format -- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 -- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. -- [`default_column_values.rs`](examples/default_column_values.rs): Implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter -- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results (Arrow ArrayRefs) into Rust structs -- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify, analyze and coerce `Expr`s -- [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. 
-- [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients -- [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros -- [`memory_pool_tracking.rs`](examples/memory_pool_tracking.rs): Demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages -- [`memory_pool_execution_plan.rs`](examples/memory_pool_execution_plan.rs): Shows how to implement memory-aware ExecutionPlan with memory reservation and spilling -- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates -- [`parquet_embedded_index.rs`](examples/parquet_embedded_index.rs): Store a custom index inside a Parquet file and use it to speed up queries -- [`parquet_encrypted.rs`](examples/parquet_encrypted.rs): Read and write encrypted Parquet files using DataFusion -- [`parquet_encrypted_with_kms.rs`](examples/parquet_encrypted_with_kms.rs): Read and write encrypted Parquet files using an encryption factory -- [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries -- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution -- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. -- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from DataFusion `Expr` and `LogicalPlan` -- [`planner_api.rs`](examples/planner_api.rs) APIs to manipulate logical and physical plans -- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics -- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 -- [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP -- [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions -- [`remote_catalog.rs`](examples/regexp.rs): Examples of interfacing with a remote catalog (e.g. over a network) -- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) -- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) -- [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with DataFusion structures -- [`sql_frontend.rs`](examples/sql_frontend.rs): Create LogicalPlans (only) from sql strings -- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` -- [`sql_query.rs`](examples/memtable.rs): Query data using SQL (in memory `RecordBatches`, local Parquet files) -- [`date_time_function.rs`](examples/date_time_function.rs): Examples of date-time related functions and queries. - -## Distributed - -- [`flight_client.rs`](examples/flight/flight_client.rs) and [`flight_server.rs`](examples/flight/flight_server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol. 
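One note on the `[lints]` comment in the `datafusion-examples/Cargo.toml` hunk above: because Cargo cannot yet combine `workspace = true` with extra per-crate lint rules (rust-lang/cargo#13157), any additional lints have to live in the crate root as attributes. A hedged sketch of that workaround; the specific lints below are illustrative, not the ones this crate actually enables:

```rust
// Crate root (lib.rs or an example's main.rs): crate-level lint attributes
// layered on top of the lints inherited via `[lints] workspace = true`.
#![deny(clippy::clone_on_ref_ptr)] // illustrative extra lint
#![warn(missing_docs)] // illustrative extra lint

//! Crate documentation goes here.

/// Placeholder item so the sketch compiles under `missing_docs`.
pub fn example() {}
```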
+## Builtin Functions Examples + +### Group: `builtin_functions` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ---------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------- | +| date_time | [`builtin_functions/date_time.rs`](examples/builtin_functions/date_time.rs) | Examples of date-time related functions and queries | +| function_factory | [`builtin_functions/function_factory.rs`](examples/builtin_functions/function_factory.rs) | Register `CREATE FUNCTION` handler to implement SQL macros | +| regexp | [`builtin_functions/regexp.rs`](examples/builtin_functions/regexp.rs) | Examples of using regular expression functions | + +## Custom Data Source Examples + +### Group: `custom_data_source` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | +| adapter_serialization | [`custom_data_source/adapter_serialization.rs`](examples/custom_data_source/adapter_serialization.rs) | Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception | +| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | +| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a streaming SQL query against CSV data | +| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | +| custom_file_casts | [`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | +| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | +| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | +| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | + +## Data IO Examples + +### Group: `data_io` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ---------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------- | +| catalog | [`data_io/catalog.rs`](examples/data_io/catalog.rs) | Register tables into a custom catalog | +| in_memory_object_store | [`data_io/in_memory_object_store.rs`](examples/data_io/in_memory_object_store.rs) | Read CSV from an in-memory object store (pattern applies to JSON/Parquet) | +| json_shredding | [`data_io/json_shredding.rs`](examples/data_io/json_shredding.rs) | Implement filter rewriting for JSON shredding | +| parquet_adv_idx | [`data_io/parquet_advanced_index.rs`](examples/data_io/parquet_advanced_index.rs) | Create a secondary index across multiple parquet files | +| parquet_emb_idx | 
[`data_io/parquet_embedded_index.rs`](examples/data_io/parquet_embedded_index.rs) | Store a custom index inside Parquet files | +| parquet_enc | [`data_io/parquet_encrypted.rs`](examples/data_io/parquet_encrypted.rs) | Read & write encrypted Parquet files | +| parquet_enc_with_kms | [`data_io/parquet_encrypted_with_kms.rs`](examples/data_io/parquet_encrypted_with_kms.rs) | Encrypted Parquet I/O using a KMS-backed factory | +| parquet_exec_visitor | [`data_io/parquet_exec_visitor.rs`](examples/data_io/parquet_exec_visitor.rs) | Extract statistics by visiting an ExecutionPlan | +| parquet_idx | [`data_io/parquet_index.rs`](examples/data_io/parquet_index.rs) | Create a secondary index | +| query_http_csv | [`data_io/query_http_csv.rs`](examples/data_io/query_http_csv.rs) | Query CSV files via HTTP | +| remote_catalog | [`data_io/remote_catalog.rs`](examples/data_io/remote_catalog.rs) | Interact with a remote catalog | + +## DataFrame Examples + +### Group: `dataframe` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------- | ------------------------------------------------------- | +| cache_factory | [`dataframe/cache_factory.rs`](examples/dataframe/cache_factory.rs) | Custom lazy caching for DataFrames using `CacheFactory` | +| dataframe | [`dataframe/dataframe.rs`](examples/dataframe/dataframe.rs) | Query DataFrames from various sources and write output | +| deserialize_to_struct | [`dataframe/deserialize_to_struct.rs`](examples/dataframe/deserialize_to_struct.rs) | Convert Arrow arrays into Rust structs | + +## Execution Monitoring Examples + +### Group: `execution_monitoring` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ------------------ | ------------------------------------------------------------------------------------------------------------------- | ---------------------------------------- | +| mem_pool_exec_plan | [`execution_monitoring/memory_pool_execution_plan.rs`](examples/execution_monitoring/memory_pool_execution_plan.rs) | Memory-aware ExecutionPlan with spilling | +| mem_pool_tracking | [`execution_monitoring/memory_pool_tracking.rs`](examples/execution_monitoring/memory_pool_tracking.rs) | Demonstrates memory tracking | +| tracing | [`execution_monitoring/tracing.rs`](examples/execution_monitoring/tracing.rs) | Demonstrates tracing integration | + +## Extension Types Examples + +### Group: `extension_types` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ----------- | --------------------------------------------------------------------------- | ------------------------------------ | +| temperature | [`extension_types/temperature.rs`](examples/extension_types/temperature.rs) | Extension type for temperature data. 
| + +## External Dependency Examples + +### Group: `external_dependency` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------- | ------------------------------------------------------------------------------------------- | ---------------------------------------- | +| dataframe_to_s3 | [`external_dependency/dataframe_to_s3.rs`](examples/external_dependency/dataframe_to_s3.rs) | Query DataFrames and write results to S3 | +| query_aws_s3 | [`external_dependency/query_aws_s3.rs`](examples/external_dependency/query_aws_s3.rs) | Query S3-backed data using object_store | + +## Flight Examples + +### Group: `flight` + +#### Category: Distributed + +| Subcommand | File Path | Description | +| ---------- | ------------------------------------------------------- | ------------------------------------------------------ | +| client | [`flight/client.rs`](examples/flight/client.rs) | Execute SQL queries via Arrow Flight protocol | +| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | +| sql_server | [`flight/sql_server.rs`](examples/flight/sql_server.rs) | Standalone SQL server for JDBC clients | + +## Proto Examples + +### Group: `proto` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ------------------------ | --------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | +| composed_extension_codec | [`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | +| expression_deduplication | [`proto/expression_deduplication.rs`](examples/proto/expression_deduplication.rs) | Example of expression caching/deduplication using the codec decorator pattern | + +## Query Planning Examples + +### Group: `query_planning` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| -------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------ | +| analyzer_rule | [`query_planning/analyzer_rule.rs`](examples/query_planning/analyzer_rule.rs) | Custom AnalyzerRule to change query semantics | +| expr_api | [`query_planning/expr_api.rs`](examples/query_planning/expr_api.rs) | Create, execute, analyze, and coerce Exprs | +| optimizer_rule | [`query_planning/optimizer_rule.rs`](examples/query_planning/optimizer_rule.rs) | Replace predicates via a custom OptimizerRule | +| parse_sql_expr | [`query_planning/parse_sql_expr.rs`](examples/query_planning/parse_sql_expr.rs) | Parse SQL into DataFusion Expr | +| plan_to_sql | [`query_planning/plan_to_sql.rs`](examples/query_planning/plan_to_sql.rs) | Generate SQL from expressions or plans | +| planner_api | [`query_planning/planner_api.rs`](examples/query_planning/planner_api.rs) | APIs for logical and physical plan manipulation | +| pruning | [`query_planning/pruning.rs`](examples/query_planning/pruning.rs) | Use pruning to skip irrelevant files | +| thread_pools | [`query_planning/thread_pools.rs`](examples/query_planning/thread_pools.rs) | Configure custom thread pools for DataFusion execution | + +## Relation Planner Examples + +### Group: `relation_planner` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------- | ------------------------------------------------------------------------------------- | 
------------------------------------------ | +| match_recognize | [`relation_planner/match_recognize.rs`](examples/relation_planner/match_recognize.rs) | Implement MATCH_RECOGNIZE pattern matching | +| pivot_unpivot | [`relation_planner/pivot_unpivot.rs`](examples/relation_planner/pivot_unpivot.rs) | Implement PIVOT / UNPIVOT | +| table_sample | [`relation_planner/table_sample.rs`](examples/relation_planner/table_sample.rs) | Implement TABLESAMPLE | + +## SQL Ops Examples + +### Group: `sql_ops` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ----------------- | ----------------------------------------------------------------------- | -------------------------------------------------- | +| analysis | [`sql_ops/analysis.rs`](examples/sql_ops/analysis.rs) | Analyze SQL queries | +| custom_sql_parser | [`sql_ops/custom_sql_parser.rs`](examples/sql_ops/custom_sql_parser.rs) | Implement a custom SQL parser to extend DataFusion | +| frontend | [`sql_ops/frontend.rs`](examples/sql_ops/frontend.rs) | Build LogicalPlans from SQL | +| query | [`sql_ops/query.rs`](examples/sql_ops/query.rs) | Query data using SQL | + +## UDF Examples + +### Group: `udf` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------- | ----------------------------------------------------------- | ----------------------------------------------- | +| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | +| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | +| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | +| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function | +| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | +| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | +| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | +| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | +| table_list_udtf | [`udf/table_list_udtf.rs`](examples/udf/table_list_udtf.rs) | Session-aware UDTF table list example | diff --git a/datafusion-examples/data/README.md b/datafusion-examples/data/README.md new file mode 100644 index 0000000000000..e8296a8856e60 --- /dev/null +++ b/datafusion-examples/data/README.md @@ -0,0 +1,25 @@ + + +## Example datasets + +| Filename | Path | Description | +| ----------- | --------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `cars.csv` | [`data/csv/cars.csv`](./csv/cars.csv) | Time-series–like dataset containing car identifiers, speed values, and timestamps. Used in window function and time-based query examples (e.g. ordering, window frames). | +| `regex.csv` | [`data/csv/regex.csv`](./csv/regex.csv) | Dataset for regular expression examples. Contains input values, regex patterns, replacement strings, and optional flags. Covers ASCII, Unicode, and locale-specific text processing. 
| diff --git a/datafusion-examples/data/csv/cars.csv b/datafusion-examples/data/csv/cars.csv new file mode 100644 index 0000000000000..bc40f3b01e7a5 --- /dev/null +++ b/datafusion-examples/data/csv/cars.csv @@ -0,0 +1,26 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 diff --git a/datafusion-examples/data/csv/regex.csv b/datafusion-examples/data/csv/regex.csv new file mode 100644 index 0000000000000..b249c39522b60 --- /dev/null +++ b/datafusion-examples/data/csv/regex.csv @@ -0,0 +1,12 @@ +values,patterns,replacement,flags +abc,^(a),bb\1bb,i +ABC,^(A).*,B,i +aBc,(b|d),e,i +AbC,(B|D),e, +aBC,^(b|c),d, +4000,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +4010,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +Düsseldorf,[\p{Letter}-]+,München, +Москва,[\p{L}-]+,Moscow, +Köln,[a-zA-Z]ö[a-zA-Z]{2},Koln, +اليوم,^\p{Arabic}+$,Today, \ No newline at end of file diff --git a/datafusion-examples/examples/date_time_functions.rs b/datafusion-examples/examples/builtin_functions/date_time.rs similarity index 96% rename from datafusion-examples/examples/date_time_functions.rs rename to datafusion-examples/examples/builtin_functions/date_time.rs index 2628319ae31f0..08d4bc6e29978 100644 --- a/datafusion-examples/examples/date_time_functions.rs +++ b/datafusion-examples/examples/builtin_functions/date_time.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::sync::Arc; use arrow::array::{Date32Array, Int32Array}; @@ -26,8 +28,20 @@ use datafusion::common::assert_contains; use datafusion::error::Result; use datafusion::prelude::*; -#[tokio::main] -async fn main() -> Result<()> { +/// Example: Working with Date and Time Functions +/// +/// This example demonstrates how to work with various date and time +/// functions in DataFusion using both the DataFrame API and SQL queries. +/// +/// It includes: +/// - `make_date`: building `DATE` values from year, month, and day columns +/// - `to_date`: converting string expressions into `DATE` values +/// - `to_timestamp`: parsing strings or numeric values into `TIMESTAMP`s +/// - `to_char`: formatting dates, timestamps, and durations as strings +/// +/// Together, these examples show how to create, convert, and format temporal +/// data using DataFusion’s built-in functions. 
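The doc comment above lists `make_date`, `to_date`, `to_timestamp`, and `to_char`. A compact, self-contained sketch (not part of the example file) exercising all four in one query; the literal values are illustrative:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Build, parse, and format temporal values with the built-in functions.
    ctx.sql(
        "SELECT make_date(2024, 1, 15)                      AS built, \
                to_date('2024-01-15')                       AS parsed_date, \
                to_timestamp('2024-01-15T12:30:00')         AS parsed_ts, \
                to_char(make_date(2024, 1, 15), '%Y/%m/%d') AS formatted",
    )
    .await?
    .show()
    .await?;
    Ok(())
}
```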
+pub async fn date_time() -> Result<()> { query_make_date().await?; query_to_date().await?; query_to_timestamp().await?; @@ -167,12 +181,13 @@ async fn query_make_date() -> Result<()> { // invalid column values will result in an error let result = ctx - .sql("select make_date(2024, null, 23)") + .sql("select make_date(2024, '', 23)") .await? .collect() .await; - let expected = "Execution error: Unable to parse date from null/empty value"; + let expected = + "Arrow error: Cast error: Cannot cast string '' to value of Int32 type"; assert_contains!(result.unwrap_err().to_string(), expected); // invalid date values will also result in an error @@ -182,7 +197,7 @@ async fn query_make_date() -> Result<()> { .collect() .await; - let expected = "Execution error: Unable to parse date from 2024, 1, 32"; + let expected = "Execution error: Day value '32' is out of range"; assert_contains!(result.unwrap_err().to_string(), expected); Ok(()) diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/builtin_functions/function_factory.rs similarity index 96% rename from datafusion-examples/examples/function_factory.rs rename to datafusion-examples/examples/builtin_functions/function_factory.rs index d4312ae594091..3cc77371d44ce 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/builtin_functions/function_factory.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{exec_datafusion_err, exec_err, internal_err, DataFusionError}; +use datafusion::common::{DataFusionError, exec_datafusion_err, exec_err, internal_err}; use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, }; -use datafusion::logical_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion::logical_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::{ ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, @@ -42,8 +44,7 @@ use std::sync::Arc; /// /// This example is rather simple and does not cover all cases required for a /// real implementation. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn function_factory() -> Result<()> { // First we must configure the SessionContext with our function factory let ctx = SessionContext::new() // register custom function factory @@ -117,10 +118,6 @@ struct ScalarFunctionWrapper { } impl ScalarUDFImpl for ScalarFunctionWrapper { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { &self.name } @@ -144,7 +141,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { let replacement = Self::replacement(&self.expr, &args)?; diff --git a/datafusion-examples/examples/builtin_functions/main.rs b/datafusion-examples/examples/builtin_functions/main.rs new file mode 100644 index 0000000000000..42ca15f91935d --- /dev/null +++ b/datafusion-examples/examples/builtin_functions/main.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These are miscellaneous function-related examples +//! +//! These examples demonstrate miscellaneous function-related features. +//! +//! ## Usage +//! ```bash +//! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `date_time` +//! (file: date_time.rs, desc: Examples of date-time related functions and queries) +//! +//! - `function_factory` +//! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) +//! +//! - `regexp` +//! (file: regexp.rs, desc: Examples of using regular expression functions) + +mod date_time; +mod function_factory; +mod regexp; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + DateTime, + FunctionFactory, + Regexp, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "builtin_functions"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::DateTime => date_time::date_time().await?, + ExampleKind::FunctionFactory => function_factory::function_factory().await?, + ExampleKind::Regexp => regexp::regexp().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/regexp.rs b/datafusion-examples/examples/builtin_functions/regexp.rs similarity index 74% rename from datafusion-examples/examples/regexp.rs rename to datafusion-examples/examples/builtin_functions/regexp.rs index 12d115b9b502c..97dc71b94e934 100644 --- a/datafusion-examples/examples/regexp.rs +++ b/datafusion-examples/examples/builtin_functions/regexp.rs @@ -1,5 +1,4 @@ // Licensed to the Apache Software Foundation (ASF) under one -// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. 
The ASF licenses this file @@ -16,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::common::{assert_batches_eq, assert_contains}; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example demonstrates how to use the regexp_* functions /// @@ -28,15 +30,12 @@ use datafusion::prelude::*; /// /// Supported flags can be found at /// https://docs.rs/regex/latest/regex/#grouping-and-flags -#[tokio::main] -async fn main() -> Result<()> { +pub async fn regexp() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_csv( - "examples", - "../../datafusion/physical-expr/tests/data/regex.csv", - CsvReadOptions::new(), - ) - .await?; + let dataset = ExampleDataset::Regex; + + ctx.register_csv("examples", dataset.path_str()?, CsvReadOptions::new()) + .await?; // // @@ -112,11 +111,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------------------+----------------------------------------------------+", - "| regexp_like(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_like(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", - "+---------------------------------------------------+----------------------------------------------------+", - "| true | true |", - "+---------------------------------------------------+----------------------------------------------------+", + "+---------------------------------------------------+----------------------------------------------------+", + "| regexp_like(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_like(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", + "+---------------------------------------------------+----------------------------------------------------+", + "| true | true |", + "+---------------------------------------------------+----------------------------------------------------+", ], &result ); @@ -242,11 +241,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------------------+-----------------------------------------------------+", - "| regexp_match(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_match(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", - "+----------------------------------------------------+-----------------------------------------------------+", - "| [John Smith] | [Smith Jones] |", - "+----------------------------------------------------+-----------------------------------------------------+", + "+----------------------------------------------------+-----------------------------------------------------+", + "| regexp_match(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_match(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", + "+----------------------------------------------------+-----------------------------------------------------+", + "| [John Smith] | [Smith Jones] |", + "+----------------------------------------------------+-----------------------------------------------------+", ], &result ); @@ -268,21 +267,21 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------------------------------------------------------------------------+", - "| regexp_replace(examples.values,examples.patterns,examples.replacement,concat(Utf8(\"g\"),examples.flags)) |", - "+---------------------------------------------------------------------------------------------------------+", - "| bbabbbc |", - "| B |", - "| aec |", - "| AbC 
|", - "| aBC |", - "| 4000 |", - "| xyz |", - "| München |", - "| Moscow |", - "| Koln |", - "| Today |", - "+---------------------------------------------------------------------------------------------------------+", + "+---------------------------------------------------------------------------------------------------------+", + "| regexp_replace(examples.values,examples.patterns,examples.replacement,concat(Utf8(\"g\"),examples.flags)) |", + "+---------------------------------------------------------------------------------------------------------+", + "| bbabbbc |", + "| B |", + "| aec |", + "| AbC |", + "| aBC |", + "| 4000 |", + "| xyz |", + "| München |", + "| Moscow |", + "| Koln |", + "| Today |", + "+---------------------------------------------------------------------------------------------------------+", ], &result ); @@ -296,11 +295,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------------------------------------------------------------------+", - "| regexp_replace(Utf8(\"foobarbaz\"),Utf8(\"b(..)\"),Utf8(\"X\\1Y\"),Utf8(\"g\")) |", - "+------------------------------------------------------------------------+", - "| fooXarYXazY |", - "+------------------------------------------------------------------------+", + "+------------------------------------------------------------------------+", + "| regexp_replace(Utf8(\"foobarbaz\"),Utf8(\"b(..)\"),Utf8(\"X\\1Y\"),Utf8(\"g\")) |", + "+------------------------------------------------------------------------+", + "| fooXarYXazY |", + "+------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs new file mode 100644 index 0000000000000..d82bd2097ce1d --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalProtoConverterExtension` +//! trait's interception methods (`execution_plan_to_proto` and +//! `proto_to_execution_plan`) to implement custom serialization logic. +//! +//! The key insight is that `FileScanConfig::expr_adapter_factory` is NOT serialized by +//! default. This example shows how to: +//! 1. Detect plans with custom adapters during serialization +//! 2. Wrap them as Extension nodes with JSON-serialized adapter metadata +//! 3. Store the inner DataSourceExec (without adapter) as a child in the extension's inputs field +//! 4. Unwrap and restore the adapter during deserialization +//! +//! 
+//! This demonstrates nested serialization (protobuf outer, JSON inner) and the
+//! power of `PhysicalProtoConverterExtension`. Both plan and expression
+//! serialization route through converter hooks, enabling interception at every
+//! node in the tree.
+
+use std::fmt::Debug;
+use std::sync::Arc;
+
+use arrow::array::record_batch;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::assert_batches_eq;
+use datafusion::common::{Result, not_impl_err};
+use datafusion::datasource::listing::{
+    ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
+};
+use datafusion::datasource::physical_plan::{FileScanConfig, FileScanConfigBuilder};
+use datafusion::datasource::source::DataSourceExec;
+use datafusion::execution::TaskContext;
+use datafusion::execution::context::SessionContext;
+use datafusion::execution::object_store::ObjectStoreUrl;
+use datafusion::parquet::arrow::ArrowWriter;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionConfig;
+use datafusion_physical_expr_adapter::{
+    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
+};
+use datafusion_proto::bytes::{
+    physical_plan_from_bytes_with_proto_converter,
+    physical_plan_to_bytes_with_proto_converter,
+};
+use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter;
+use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter;
+use datafusion_proto::physical_plan::{
+    PhysicalExtensionCodec, PhysicalPlanDecodeContext, PhysicalProtoConverterExtension,
+};
+use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType;
+use datafusion_proto::protobuf::{
+    PhysicalExprNode, PhysicalExtensionNode, PhysicalPlanNode,
+};
+use object_store::memory::InMemory;
+use object_store::path::Path;
+use object_store::{ObjectStore, ObjectStoreExt, PutPayload};
+use serde::{Deserialize, Serialize};
+
+/// Example showing how to preserve custom adapter information during plan serialization.
+///
+/// This demonstrates:
+/// 1. Creating a custom PhysicalExprAdapter with metadata
+/// 2. Using PhysicalExtensionCodec to intercept serialization
+/// 3. Wrapping adapter info as Extension nodes
+/// 4. Restoring adapters during deserialization
+pub async fn adapter_serialization() -> Result<()> {
+    println!("=== PhysicalExprAdapter Serialization Example ===\n");
+
+    // Step 1: Create sample Parquet data in memory
+    println!("Step 1: Creating sample Parquet data...");
+    let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+    let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))?;
+    let path = Path::from("data.parquet");
+    write_parquet(&store, &path, &batch).await?;
+
+    // Step 2: Set up session with custom adapter
+    println!("Step 2: Setting up session with custom adapter...");
+    let logical_schema =
+        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+    let mut cfg = SessionConfig::new();
+    cfg.options_mut().execution.parquet.pushdown_filters = true;
+    let ctx = SessionContext::new_with_config(cfg);
+    ctx.runtime_env().register_object_store(
+        ObjectStoreUrl::parse("memory://")?.as_ref(),
+        Arc::clone(&store),
+    );
+
+    // Create a table with our custom MetadataAdapterFactory
+    let adapter_factory = Arc::new(MetadataAdapterFactory::new("v1"));
+    let listing_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///data.parquet")?)
+            .infer_options(&ctx.state())
+            .await?
+            .with_schema(logical_schema)
+            .with_expr_adapter_factory(
+                Arc::clone(&adapter_factory) as Arc<dyn PhysicalExprAdapterFactory>
+            );
+    let table = ListingTable::try_new(listing_config)?;
+    ctx.register_table("my_table", Arc::new(table))?;
+
+    // Step 3: Create physical plan with filter
+    println!("Step 3: Creating physical plan with filter...");
+    let df = ctx.sql("SELECT * FROM my_table WHERE id > 5").await?;
+    let original_plan = df.create_physical_plan().await?;
+
+    // Verify adapter is present in original plan
+    let has_adapter_before = verify_adapter_in_plan(&original_plan, "original");
+    println!("  Original plan has adapter: {has_adapter_before}");
+
+    // Step 4: Serialize with our custom codec
+    println!("\nStep 4: Serializing plan with AdapterPreservingCodec...");
+    let codec = AdapterPreservingCodec;
+    let bytes = physical_plan_to_bytes_with_proto_converter(
+        Arc::clone(&original_plan),
+        &codec,
+        &codec,
+    )?;
+    println!("  Serialized {} bytes", bytes.len());
+    println!("  (DataSourceExec with adapter was wrapped as PhysicalExtensionNode)");
+
+    // Step 5: Deserialize with our custom codec
+    println!("\nStep 5: Deserializing plan with AdapterPreservingCodec...");
+    let task_ctx = ctx.task_ctx();
+    let restored_plan =
+        physical_plan_from_bytes_with_proto_converter(&bytes, &task_ctx, &codec, &codec)?;
+
+    // Verify adapter is restored
+    let has_adapter_after = verify_adapter_in_plan(&restored_plan, "restored");
+    println!("  Restored plan has adapter: {has_adapter_after}");
+
+    // Step 6: Execute and compare results
+    println!("\nStep 6: Executing plans and comparing results...");
+    let original_results =
+        datafusion::physical_plan::collect(Arc::clone(&original_plan), task_ctx.clone())
+            .await?;
+    let restored_results =
+        datafusion::physical_plan::collect(restored_plan, task_ctx).await?;
+
+    #[rustfmt::skip]
+    let expected = [
+        "+----+",
+        "| id |",
+        "+----+",
+        "| 6  |",
+        "| 7  |",
+        "| 8  |",
+        "| 9  |",
+        "| 10 |",
+        "+----+",
+    ];
+
+    println!("\n  Original plan results:");
+    arrow::util::pretty::print_batches(&original_results)?;
+    assert_batches_eq!(expected, &original_results);
+
+    println!("\n  Restored plan results:");
+    arrow::util::pretty::print_batches(&restored_results)?;
+    assert_batches_eq!(expected, &restored_results);
+
+    println!("\n=== Example Complete! ===");
+    println!("Key takeaways:");
+    println!(
+        "  1. PhysicalProtoConverterExtension provides execution_plan_to_proto/proto_to_execution_plan hooks"
+    );
+    println!("  2. Custom metadata can be wrapped as PhysicalExtensionNode");
+    println!("  3. Nested serialization (protobuf + JSON) works seamlessly");
+    println!(
+        "  4. Both plans produce identical results despite serialization round-trip"
+    );
+    println!("  5. Adapters are fully preserved through the serialization round-trip");
+
+    Ok(())
+}
+
+// ============================================================================
+// MetadataAdapter - A simple custom adapter with a tag
+// ============================================================================
+
+/// A custom PhysicalExprAdapter that wraps another adapter.
+/// The tag metadata is stored in the factory, not the adapter itself.
+#[derive(Debug)]
+struct MetadataAdapter {
+    inner: Arc<dyn PhysicalExprAdapter>,
+}
+
+impl PhysicalExprAdapter for MetadataAdapter {
+    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+        // Simply delegate to inner adapter
+        self.inner.rewrite(expr)
+    }
+}
+
+// ============================================================================
+// MetadataAdapterFactory - Factory for creating MetadataAdapter instances
+// ============================================================================
+
+/// Factory for creating MetadataAdapter instances.
+/// The tag is stored in the factory and extracted via Debug formatting in `extract_adapter_tag`.
+#[derive(Debug)]
+struct MetadataAdapterFactory {
+    // Note: This field is read via Debug formatting in `extract_adapter_tag`.
+    // Rust's dead code analysis doesn't recognize Debug-based field access.
+    // In PR #19234, this field is used by `with_partition_values`, but that method
+    // doesn't exist in upstream DataFusion's PhysicalExprAdapter trait.
+    #[expect(dead_code)]
+    tag: String,
+}
+
+impl MetadataAdapterFactory {
+    fn new(tag: impl Into<String>) -> Self {
+        Self { tag: tag.into() }
+    }
+}
+
+impl PhysicalExprAdapterFactory for MetadataAdapterFactory {
+    fn create(
+        &self,
+        logical_file_schema: SchemaRef,
+        physical_file_schema: SchemaRef,
+    ) -> Result<Arc<dyn PhysicalExprAdapter>> {
+        let inner = DefaultPhysicalExprAdapterFactory
+            .create(logical_file_schema, physical_file_schema)?;
+        Ok(Arc::new(MetadataAdapter { inner }))
+    }
+}
+
+// ============================================================================
+// AdapterPreservingCodec - Custom codec that preserves adapters
+// ============================================================================
+
+/// Extension payload structure for serializing adapter info
+#[derive(Serialize, Deserialize)]
+struct ExtensionPayload {
+    /// Marker to identify this is our custom extension
+    marker: String,
+    /// JSON-serialized adapter metadata
+    adapter_metadata: AdapterMetadata,
+}
+
+/// Metadata about the adapter to recreate it during deserialization
+#[derive(Serialize, Deserialize)]
+struct AdapterMetadata {
+    /// The adapter tag (e.g., "v1")
+    tag: String,
+}
+
+const EXTENSION_MARKER: &str = "adapter_preserving_extension_v1";
+
+/// A codec that intercepts serialization to preserve adapter information.
+#[derive(Debug)]
+struct AdapterPreservingCodec;
+
+impl PhysicalExtensionCodec for AdapterPreservingCodec {
+    // Required method: decode custom extension nodes
+    fn try_decode(
+        &self,
+        buf: &[u8],
+        inputs: &[Arc<dyn ExecutionPlan>],
+        _ctx: &TaskContext,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Try to parse as our extension payload
+        if let Ok(payload) = serde_json::from_slice::<ExtensionPayload>(buf)
+            && payload.marker == EXTENSION_MARKER
+        {
+            if inputs.len() != 1 {
+                return Err(datafusion::error::DataFusionError::Plan(format!(
+                    "Extension node expected exactly 1 child, got {}",
+                    inputs.len()
+                )));
+            }
+            let inner_plan = inputs[0].clone();
+
+            // Recreate the adapter factory
+            let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag);
+
+            // Inject adapter into the plan
+            return inject_adapter_into_plan(inner_plan, adapter_factory);
+        }
+
+        not_impl_err!("Unknown extension type")
+    }
+
+    // Required method: encode custom execution plans
+    fn try_encode(
+        &self,
+        _node: Arc<dyn ExecutionPlan>,
+        _buf: &mut Vec<u8>,
+    ) -> Result<()> {
+        // We don't need this for the example - adapter wrapping happens in
+        // `execution_plan_to_proto` instead.
+        not_impl_err!(
+            "try_encode not used - adapter wrapping happens in execution_plan_to_proto"
+        )
+    }
+}
+
+impl PhysicalProtoConverterExtension for AdapterPreservingCodec {
+    fn execution_plan_to_proto(
+        &self,
+        plan: &Arc<dyn ExecutionPlan>,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<PhysicalPlanNode> {
+        // Check if this is a DataSourceExec with adapter
+        if let Some(exec) = plan.downcast_ref::<DataSourceExec>()
+            && let Some(config) = exec.data_source().downcast_ref::<FileScanConfig>()
+            && let Some(adapter_factory) = &config.expr_adapter_factory
+            && let Some(tag) = extract_adapter_tag(adapter_factory.as_ref())
+        {
+            // Try to extract our MetadataAdapterFactory's tag
+            println!("  [Serialize] Found DataSourceExec with adapter tag: {tag}");
+
+            // 1. Create adapter metadata
+            let adapter_metadata = AdapterMetadata { tag };
+
+            // 2. Serialize the inner plan to protobuf
+            // Note that this will drop the custom adapter since the default serialization cannot handle it
+            let inner_proto = PhysicalPlanNode::try_from_physical_plan_with_converter(
+                Arc::clone(plan),
+                extension_codec,
+                self,
+            )?;
+
+            // 3. Create extension payload to wrap the plan
+            // so that the custom adapter gets re-attached during deserialization
+            // The choice of JSON is arbitrary; other formats could be used.
+            let payload = ExtensionPayload {
+                marker: EXTENSION_MARKER.to_string(),
+                adapter_metadata,
+            };
+            let payload_bytes = serde_json::to_vec(&payload).map_err(|e| {
+                datafusion::error::DataFusionError::Plan(format!(
+                    "Failed to serialize payload: {e}"
+                ))
+            })?;
+
+            // 4. Return as PhysicalExtensionNode with child plan in inputs
+            return Ok(PhysicalPlanNode {
+                physical_plan_type: Some(PhysicalPlanType::Extension(
+                    PhysicalExtensionNode {
+                        node: payload_bytes,
+                        inputs: vec![inner_proto],
+                    },
+                )),
+            });
+        }
+
+        // No adapter found, not a DataSourceExec, etc. - use default serialization
+        PhysicalPlanNode::try_from_physical_plan_with_converter(
+            Arc::clone(plan),
+            extension_codec,
+            self,
+        )
+    }
+
+    // Interception point: override deserialization to unwrap adapters
+    fn proto_to_execution_plan(
+        &self,
+        proto: &PhysicalPlanNode,
+        ctx: &PhysicalPlanDecodeContext<'_>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Check if this is our custom extension wrapper
+        if let Some(PhysicalPlanType::Extension(extension)) = &proto.physical_plan_type
+            && let Ok(payload) =
+                serde_json::from_slice::<ExtensionPayload>(&extension.node)
+            && payload.marker == EXTENSION_MARKER
+        {
+            println!(
+                "  [Deserialize] Found adapter extension with tag: {}",
+                payload.adapter_metadata.tag
+            );
+
+            // Get the inner plan proto from inputs field
+            if extension.inputs.is_empty() {
+                return Err(datafusion::error::DataFusionError::Plan(
+                    "Extension node missing child plan in inputs".to_string(),
+                ));
+            }
+            let inner_proto = &extension.inputs[0];
+
+            // Deserialize the inner plan
+            let inner_plan = self.default_proto_to_execution_plan(inner_proto, ctx)?;
+
+            // Recreate the adapter factory
+            let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag);
+
+            // Inject adapter into the plan
+            return inject_adapter_into_plan(inner_plan, adapter_factory);
+        }
+
+        // Not our extension - use default deserialization
+        self.default_proto_to_execution_plan(proto, ctx)
+    }
+
+    fn proto_to_physical_expr(
+        &self,
+        proto: &PhysicalExprNode,
+        input_schema: &Schema,
+        ctx: &PhysicalPlanDecodeContext<'_>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        parse_physical_expr_with_converter(proto, input_schema, ctx, self)
+    }
+
+    fn physical_expr_to_proto(
+        &self,
+        expr: &Arc<dyn PhysicalExpr>,
+        codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<PhysicalExprNode> {
+        serialize_physical_expr_with_converter(expr, codec, self)
+    }
+}
+
+// ============================================================================
+// Helper functions
+// ============================================================================
+
+/// Write a RecordBatch to Parquet in the object store
+async fn write_parquet(
+    store: &dyn ObjectStore,
+    path: &Path,
+    batch: &arrow::record_batch::RecordBatch,
+) -> Result<()> {
+    let mut buf = vec![];
+    let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?;
+    writer.write(batch)?;
+    writer.close()?;
+
+    let payload = PutPayload::from_bytes(buf.into());
+    store.put(path, payload).await?;
+    Ok(())
+}
+
+/// Extract the tag from a MetadataAdapterFactory.
+///
+/// Note: Since `PhysicalExprAdapterFactory` doesn't provide `as_any()` for downcasting,
+/// we parse the Debug output. In a production system, you might add a dedicated trait
+/// method for metadata extraction.
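+///
+/// A minimal sketch of what such a dedicated trait could look like (the
+/// `TaggedAdapterFactory` name and method are hypothetical, not part of this
+/// PR or of DataFusion):
+///
+/// ```rust,ignore
+/// /// Factories that carry a serializable tag implement this alongside
+/// /// `PhysicalExprAdapterFactory`, so codecs can ask for the tag directly
+/// /// instead of parsing `Debug` output.
+/// trait TaggedAdapterFactory {
+///     fn tag(&self) -> Option<String>;
+/// }
+///
+/// impl TaggedAdapterFactory for MetadataAdapterFactory {
+///     fn tag(&self) -> Option<String> {
+///         Some(self.tag.clone())
+///     }
+/// }
+/// ```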
+fn extract_adapter_tag(factory: &dyn PhysicalExprAdapterFactory) -> Option<String> {
+    let debug_str = format!("{factory:?}");
+    if debug_str.contains("MetadataAdapterFactory") {
+        // Extract tag from debug output: MetadataAdapterFactory { tag: "v1" }
+        if let Some(start) = debug_str.find("tag: \"") {
+            let after_tag = &debug_str[start + 6..];
+            if let Some(end) = after_tag.find('"') {
+                return Some(after_tag[..end].to_string());
+            }
+        }
+    }
+    None
+}
+
+/// Create an adapter factory from a tag
+fn create_adapter_factory(tag: &str) -> Arc<dyn PhysicalExprAdapterFactory> {
+    Arc::new(MetadataAdapterFactory::new(tag))
+}
+
+/// Inject an adapter into a plan (assumes plan is a DataSourceExec with FileScanConfig)
+fn inject_adapter_into_plan(
+    plan: Arc<dyn ExecutionPlan>,
+    adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    if let Some(exec) = plan.downcast_ref::<DataSourceExec>()
+        && let Some(config) = exec.data_source().downcast_ref::<FileScanConfig>()
+    {
+        let new_config = FileScanConfigBuilder::from(config.clone())
+            .with_expr_adapter(Some(adapter_factory))
+            .build();
+        return Ok(DataSourceExec::from_data_source(new_config));
+    }
+    // If not a DataSourceExec with FileScanConfig, return as-is
+    Ok(plan)
+}
+
+/// Helper to verify if a plan has an adapter (for testing/validation)
+fn verify_adapter_in_plan(plan: &Arc<dyn ExecutionPlan>, label: &str) -> bool {
+    // Walk the plan tree to find DataSourceExec with adapter
+    fn check_plan(plan: &dyn ExecutionPlan) -> bool {
+        if let Some(exec) = plan.downcast_ref::<DataSourceExec>()
+            && let Some(config) = exec.data_source().downcast_ref::<FileScanConfig>()
+            && config.expr_adapter_factory.is_some()
+        {
+            return true;
+        }
+        // Check children
+        for child in plan.children() {
+            if check_plan(child.as_ref()) {
+                return true;
+            }
+        }
+        false
+    }
+
+    let has_adapter = check_plan(plan.as_ref());
+    println!("  [Verify] {label} plan adapter check: {has_adapter}");
+    has_adapter
+}
diff --git a/datafusion-examples/examples/csv_json_opener.rs b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs
similarity index 60%
rename from datafusion-examples/examples/csv_json_opener.rs
rename to datafusion-examples/examples/custom_data_source/csv_json_opener.rs
index ef2a3eaca0c88..51c0e2167053e 100644
--- a/datafusion-examples/examples/csv_json_opener.rs
+++ b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs
@@ -15,34 +15,36 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+
 use std::sync::Arc;
 
 use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::common::config::CsvOptions;
 use datafusion::{
     assert_batches_eq,
     datasource::{
         file_format::file_compression_type::FileCompressionType,
         listing::PartitionedFile,
         object_store::ObjectStoreUrl,
-        physical_plan::{CsvSource, FileSource, FileStream, JsonOpener, JsonSource},
+        physical_plan::{
+            CsvSource, FileSource, FileStreamBuilder, JsonOpener, JsonSource,
+        },
     },
     error::Result,
     physical_plan::metrics::ExecutionPlanMetricsSet,
-    test_util::aggr_test_schema,
 };
-use datafusion::datasource::{
-    physical_plan::FileScanConfigBuilder, table_schema::TableSchema,
-};
+use datafusion::datasource::physical_plan::FileScanConfigBuilder;
+use datafusion_examples::utils::datasets::ExampleDataset;
 use futures::StreamExt;
-use object_store::{local::LocalFileSystem, memory::InMemory, ObjectStore};
+use object_store::{ObjectStoreExt, local::LocalFileSystem, memory::InMemory};
 
 /// This example demonstrates using the low level [`FileStream`] / [`FileOpener`] APIs to directly
 /// read data from (CSV/JSON) into Arrow RecordBatches.
/// /// If you want to query data in CSV or JSON files, see the [`dataframe.rs`] and [`sql_query.rs`] examples -#[tokio::main] -async fn main() -> Result<()> { +pub async fn csv_json_opener() -> Result<()> { csv_opener().await?; json_opener().await?; Ok(()) @@ -50,48 +52,56 @@ async fn main() -> Result<()> { async fn csv_opener() -> Result<()> { let object_store = Arc::new(LocalFileSystem::new()); - let schema = aggr_test_schema(); - let testdata = datafusion::test_util::arrow_test_data(); - let path = format!("{testdata}/csv/aggregate_test_100.csv"); + let dataset = ExampleDataset::Cars; + let csv_path = dataset.path(); + let schema = dataset.schema(); - let path = std::path::Path::new(&path).canonicalize()?; + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; - let scan_config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - Arc::clone(&schema), - Arc::new(CsvSource::default()), - ) - .with_projection_indices(Some(vec![12, 0])) - .with_limit(Some(5)) - .with_file(PartitionedFile::new(path.display().to_string(), 10)) - .build(); - - let config = CsvSource::new(true, b',', b'"') + let source = CsvSource::new(Arc::clone(&schema)) + .with_csv_options(options) .with_comment(Some(b'#')) - .with_schema(TableSchema::from_file_schema(schema)) - .with_batch_size(8192) - .with_projection(&scan_config); + .with_batch_size(8192); + + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_projection_indices(Some(vec![0, 1]))? + .with_limit(Some(5)) + .with_file(PartitionedFile::new(csv_path.display().to_string(), 10)) + .build(); - let opener = config.create_file_opener(object_store, &scan_config, 0); + let opener = + scan_config + .file_source() + .create_file_opener(object_store, &scan_config, 0)?; let mut result = vec![]; - let mut stream = - FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new())?; + let metrics = ExecutionPlanMetricsSet::new(); + let mut stream = FileStreamBuilder::new(&scan_config) + .with_partition(0) + .with_file_opener(opener) + .with_metrics(&metrics) + .build()?; while let Some(batch) = stream.next().await.transpose()? { result.push(batch); } assert_batches_eq!( &[ - "+--------------------------------+----+", - "| c13 | c1 |", - "+--------------------------------+----+", - "| 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | c |", - "| C2GT5KVyOPZpgKVl110TyZO0NcJ434 | d |", - "| AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | b |", - "| 0keZ5G8BffGwgF2RwQD59TFzMStxCB | a |", - "| Ig1QcuKsjHXkproePdERo2w0mYzIqd | b |", - "+--------------------------------+----+", + "+-----+-------+", + "| car | speed |", + "+-----+-------+", + "| red | 20.0 |", + "| red | 20.3 |", + "| red | 21.4 |", + "| red | 21.5 |", + "| red | 19.0 |", + "+-----+-------+", ], &result ); @@ -121,24 +131,24 @@ async fn json_opener() -> Result<()> { projected, FileCompressionType::UNCOMPRESSED, Arc::new(object_store), + true, ); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - schema, - Arc::new(JsonSource::default()), + Arc::new(JsonSource::new(schema)), ) - .with_projection_indices(Some(vec![1, 0])) + .with_projection_indices(Some(vec![1, 0]))? 
.with_limit(Some(5)) .with_file(PartitionedFile::new(path.to_string(), 10)) .build(); - let mut stream = FileStream::new( - &scan_config, - 0, - Arc::new(opener), - &ExecutionPlanMetricsSet::new(), - )?; + let metrics = ExecutionPlanMetricsSet::new(); + let mut stream = FileStreamBuilder::new(&scan_config) + .with_partition(0) + .with_file_opener(Arc::new(opener)) + .with_metrics(&metrics) + .build()?; let mut result = vec![]; while let Some(batch) = stream.next().await.transpose()? { result.push(batch); diff --git a/datafusion-examples/examples/csv_sql_streaming.rs b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs similarity index 82% rename from datafusion-examples/examples/csv_sql_streaming.rs rename to datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs index 99264bbcb486d..4692086a10b26 100644 --- a/datafusion-examples/examples/csv_sql_streaming.rs +++ b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs @@ -15,44 +15,46 @@ // specific language governing permissions and limitations // under the License. -use datafusion::common::test_util::datafusion_test_data; +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example demonstrates executing a simple query against an Arrow data source (CSV) and /// fetching results with streaming aggregation and streaming window -#[tokio::main] -async fn main() -> Result<()> { +pub async fn csv_sql_streaming() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - let testdata = datafusion_test_data(); + let dataset = ExampleDataset::Cars; + let csv_path = dataset.path(); - // Register a table source and tell DataFusion the file is ordered by `ts ASC`. + // Register a table source and tell DataFusion the file is ordered by `car ASC`. // Note it is the responsibility of the user to make sure // that file indeed satisfies this condition or else incorrect answers may be produced. let asc = true; let nulls_first = true; - let sort_expr = vec![col("ts").sort(asc, nulls_first)]; + let sort_expr = vec![col("car").sort(asc, nulls_first)]; // register csv file with the execution context ctx.register_csv( "ordered_table", - &format!("{testdata}/window_1.csv"), + csv_path.to_str().unwrap(), CsvReadOptions::new().file_sort_order(vec![sort_expr]), ) .await?; // execute the query - // Following query can be executed with unbounded sources because group by expressions (e.g ts) is + // Following query can be executed with unbounded sources because group by expressions (e.g car) is // already ordered at the source. // // Unbounded sources means that if the input came from a "never ending" source (such as a FIFO // file on unix) the query could produce results incrementally as data was read. let df = ctx .sql( - "SELECT ts, MIN(inc_col), MAX(inc_col) \ + "SELECT car, MIN(speed), MAX(speed) \ FROM ordered_table \ - GROUP BY ts", + GROUP BY car", ) .await?; @@ -63,7 +65,7 @@ async fn main() -> Result<()> { // its result in streaming fashion, because its required ordering is already satisfied at the source. 
     let df = ctx
         .sql(
-            "SELECT ts, SUM(inc_col) OVER(ORDER BY ts ASC) \
+            "SELECT car, SUM(speed) OVER(ORDER BY car ASC) \
             FROM ordered_table",
         )
         .await?;
diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_data_source/custom_datasource.rs
similarity index 87%
rename from datafusion-examples/examples/custom_datasource.rs
rename to datafusion-examples/examples/custom_data_source/custom_datasource.rs
index bc865fac5a338..701a886d2a140 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_data_source/custom_datasource.rs
@@ -15,17 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
+//! See `main.rs` for how to run it.
+
 use std::collections::{BTreeMap, HashMap};
 use std::fmt::{self, Debug, Formatter};
 use std::sync::{Arc, Mutex};
 use std::time::Duration;
 
 use async_trait::async_trait;
-use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
+use datafusion::arrow::array::{UInt8Builder, UInt64Builder};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::datasource::{provider_as_source, TableProvider, TableType};
+use datafusion::common::tree_node::TreeNodeRecursion;
+use datafusion::datasource::{TableProvider, TableType, provider_as_source};
 use datafusion::error::Result;
 use datafusion::execution::context::TaskContext;
 use datafusion::logical_expr::LogicalPlanBuilder;
@@ -33,8 +35,8 @@ use datafusion::physical_expr::EquivalenceProperties;
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::memory::MemoryStream;
 use datafusion::physical_plan::{
-    project_schema, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
-    PlanProperties, SendableRecordBatchStream,
+    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
+    SendableRecordBatchStream, project_schema,
 };
 use datafusion::prelude::*;
 
@@ -42,8 +44,7 @@ use datafusion::catalog::Session;
 use tokio::time::timeout;
 
 /// This example demonstrates executing a simple query against a custom datasource
-#[tokio::main]
-async fn main() -> Result<()> {
+pub async fn custom_datasource() -> Result<()> {
     // create our custom datasource and adding some users
     let db = CustomDataSource::default();
     db.populate_users();
@@ -160,10 +161,6 @@ impl Default for CustomDataSource {
 
 #[async_trait]
 impl TableProvider for CustomDataSource {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         SchemaRef::new(Schema::new(vec![
             Field::new("id", DataType::UInt8, false),
@@ -191,10 +188,11 @@ impl TableProvider for CustomDataSource {
 struct CustomExec {
     db: CustomDataSource,
     projected_schema: SchemaRef,
-    cache: PlanProperties,
+    cache: Arc<PlanProperties>,
 }
 
 impl CustomExec {
+    #[expect(clippy::needless_pass_by_value)]
     fn new(
         projections: Option<&Vec<usize>>,
         schema: SchemaRef,
@@ -205,7 +203,7 @@ impl CustomExec {
         Self {
             db,
             projected_schema,
-            cache,
+            cache: Arc::new(cache),
         }
     }
 
@@ -232,11 +230,7 @@ impl ExecutionPlan for CustomExec {
         "CustomExec"
     }
 
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.cache
     }
 
@@ -281,4 +275,20 @@ impl ExecutionPlan for CustomExec {
             None,
         )?))
     }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
+    }
 }
diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs
similarity index 77%
rename from datafusion-examples/examples/custom_file_casts.rs
rename to datafusion-examples/examples/custom_data_source/custom_file_casts.rs
index 4d97ecd91dc64..71addc6d1bcb0 100644
--- a/datafusion-examples/examples/custom_file_casts.rs
+++ b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs
@@ -15,43 +15,44 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+
 use std::sync::Arc;
 
-use arrow::array::{record_batch, RecordBatch};
-use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
+use arrow::array::{RecordBatch, record_batch};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion::assert_batches_eq;
+use datafusion::common::Result;
 use datafusion::common::not_impl_err;
 use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion::common::{Result, ScalarValue};
 use datafusion::datasource::listing::{
     ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
 };
 use datafusion::execution::context::SessionContext;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::parquet::arrow::ArrowWriter;
-use datafusion::physical_expr::expressions::CastExpr;
 use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_expr::expressions::CastExpr;
 use datafusion::prelude::SessionConfig;
 use datafusion_physical_expr_adapter::{
     DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
 };
 use object_store::memory::InMemory;
 use object_store::path::Path;
-use object_store::{ObjectStore, PutPayload};
+use object_store::{ObjectStore, ObjectStoreExt, PutPayload};
 
 // Example showing how to implement custom casting rules to adapt file schemas.
-// This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error
-// before even reading the data.
-// Without this custom cast rule DataFusion would happily do the narrowing cast, potentially erroring only if it found a row with data it could not cast.
-
-#[tokio::main]
-async fn main() -> Result<()> {
+// This example enforces strictly widening casts: if the file type is Int64 and
+// the table type is Int32, it errors before reading the data. Without this
+// custom cast rule DataFusion would apply the narrowing cast and might only
+// error after reading a row that it could not cast.
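+//
+// As a concrete illustration of the rule (these pairs are illustrative, not an
+// exhaustive list):
+//
+//   Int16 -> Int32   widening, every value fits      => allowed
+//   Int32 -> Int64   widening                        => allowed
+//   Int64 -> Int32   narrowing, values may overflow  => rejected at plan time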
+pub async fn custom_file_casts() -> Result<()> {
     println!("=== Creating example data ===");
 
-    // Create a logical / table schema with an Int32 column
+    // Create a logical / table schema with an Int32 column (nullable)
     let logical_schema =
-        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
 
     // Create some data that can be cast (Int16 -> Int32 is widening) and some that cannot (Int64 -> Int32 is narrowing)
     let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
@@ -139,7 +140,7 @@ async fn write_data(
     Ok(())
 }
 
-/// Factory for creating DefaultValuePhysicalExprAdapter instances
+/// Factory for creating custom cast physical expression adapters
 #[derive(Debug)]
 struct CustomCastPhysicalExprAdapterFactory {
     inner: Arc<dyn PhysicalExprAdapterFactory>,
@@ -156,19 +157,19 @@ impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory {
         &self,
         logical_file_schema: SchemaRef,
         physical_file_schema: SchemaRef,
-    ) -> Arc<dyn PhysicalExprAdapter> {
+    ) -> Result<Arc<dyn PhysicalExprAdapter>> {
         let inner = self
             .inner
-            .create(logical_file_schema, Arc::clone(&physical_file_schema));
-        Arc::new(CustomCastsPhysicalExprAdapter {
+            .create(logical_file_schema, Arc::clone(&physical_file_schema))?;
+        Ok(Arc::new(CustomCastsPhysicalExprAdapter {
             physical_file_schema,
             inner,
-        })
+        }))
     }
 }
 
-/// Custom PhysicalExprAdapter that handles missing columns with default values from metadata
-/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation
+/// Custom `PhysicalExprAdapter` that wraps the default adapter and rejects
+/// narrowing file-schema casts.
 #[derive(Debug, Clone)]
 struct CustomCastsPhysicalExprAdapter {
     physical_file_schema: SchemaRef,
@@ -177,15 +178,17 @@ struct CustomCastsPhysicalExprAdapter {
 
 impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
     fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
-        // First delegate to the inner adapter to handle missing columns and discover any necessary casts
+        // First delegate to the inner adapter to handle standard schema adaptation
+        // and discover any necessary casts.
         expr = self.inner.rewrite(expr)?;
-        // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression
-        // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138).
+        // Now apply custom casting rules or swap CastExprs for a custom cast
+        // kernel / expression. For example, DataFusion Comet has a custom cast
+        // kernel in its native Spark expression implementation.
         expr.transform(|expr| {
-            if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
+            if let Some(cast) = expr.downcast_ref::<CastExpr>() {
                 let input_data_type = cast.expr().data_type(&self.physical_file_schema)?;
-                let output_data_type = cast.data_type(&self.physical_file_schema)?;
+                let output_data_type = cast.target_field().data_type();
                 if !cast.is_bigger_cast(&input_data_type) {
                     return not_impl_err!(
                         "Unsupported CAST from {input_data_type} to {output_data_type}"
@@ -196,14 +199,4 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
         })
         .data()
     }
-
-    fn with_partition_values(
-        &self,
-        partition_values: Vec<(FieldRef, ScalarValue)>,
-    ) -> Arc<dyn PhysicalExprAdapter> {
-        Arc::new(Self {
-            inner: self.inner.with_partition_values(partition_values),
-            ..self.clone()
-        })
-    }
 }
diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_data_source/custom_file_format.rs
similarity index 92%
rename from datafusion-examples/examples/custom_file_format.rs
rename to datafusion-examples/examples/custom_data_source/custom_file_format.rs
index 67fe642fd46ee..0cfbe11877e4d 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_data_source/custom_file_format.rs
@@ -15,7 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, sync::Arc};
+//! See `main.rs` for how to run it.
+
+use std::sync::Arc;
 
 use arrow::{
     array::{AsArray, RecordBatch, StringArray, UInt8Array},
@@ -25,12 +27,13 @@ use datafusion::{
     catalog::Session,
     common::{GetExt, Statistics},
     datasource::{
+        MemTable,
         file_format::{
-            csv::CsvFormatFactory, file_compression_type::FileCompressionType,
-            FileFormat, FileFormatFactory,
+            FileFormat, FileFormatFactory, csv::CsvFormatFactory,
+            file_compression_type::FileCompressionType,
         },
         physical_plan::{FileScanConfig, FileSinkConfig, FileSource},
-        MemTable,
+        table_schema::TableSchema,
     },
     error::Result,
     execution::session_state::SessionStateBuilder,
@@ -47,6 +50,42 @@
 /// TSVFileFormatFactory is responsible for creating instances of TSVFileFormat.
 /// The former, once registered with the SessionState, will then be used
 /// to facilitate SQL operations on TSV files, such as `COPY TO` shown here.
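+///
+/// Once the factory is registered, the custom format is driven entirely from
+/// SQL. A typical interaction (the target path is illustrative) looks like:
+///
+/// ```rust,ignore
+/// ctx.sql("COPY mem_table TO '/tmp/mem_table.tsv' STORED AS TSV;").await?;
+/// ```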
+pub async fn custom_file_format() -> Result<()> {
+    // Create a new context with the default configuration
+    let mut state = SessionStateBuilder::new().with_default_features().build();
+
+    // Register the custom file format
+    let file_format = Arc::new(TSVFileFactory::new());
+    state.register_file_format(file_format, true)?;
+
+    // Create a new context with the custom file format
+    let ctx = SessionContext::new_with_state(state);
+
+    let mem_table = create_mem_table();
+    ctx.register_table("mem_table", mem_table)?;
+
+    let temp_dir = tempdir().unwrap();
+    let table_save_path = temp_dir.path().join("mem_table.tsv");
+
+    let d = ctx
+        .sql(&format!(
+            "COPY mem_table TO '{}' STORED AS TSV;",
+            table_save_path.display(),
+        ))
+        .await?;
+
+    let results = d.collect().await?;
+    println!(
+        "Number of inserted rows: {:?}",
+        (results[0]
+            .column_by_name("count")
+            .unwrap()
+            .as_primitive::<UInt64Type>()
+            .value(0))
+    );
+
+    Ok(())
+}
 
 #[derive(Debug)]
 /// Custom file format that reads and writes TSV files
@@ -65,10 +104,6 @@ impl TSVFileFormat {
 
 #[async_trait::async_trait]
 impl FileFormat for TSVFileFormat {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn get_ext(&self) -> String {
         "tsv".to_string()
     }
@@ -128,8 +163,8 @@ impl FileFormat for TSVFileFormat {
             .await
     }
 
-    fn file_source(&self) -> Arc<dyn FileSource> {
-        self.csv_file_format.file_source()
+    fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
+        self.csv_file_format.file_source(table_schema)
     }
 }
 
@@ -168,10 +203,6 @@ impl FileFormatFactory for TSVFileFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         todo!()
     }
-
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
 }
 
 impl GetExt for TSVFileFactory {
@@ -180,44 +211,6 @@ impl GetExt for TSVFileFactory {
     }
 }
 
-#[tokio::main]
-async fn main() -> Result<()> {
-    // Create a new context with the default configuration
-    let mut state = SessionStateBuilder::new().with_default_features().build();
-
-    // Register the custom file format
-    let file_format = Arc::new(TSVFileFactory::new());
-    state.register_file_format(file_format, true).unwrap();
-
-    // Create a new context with the custom file format
-    let ctx = SessionContext::new_with_state(state);
-
-    let mem_table = create_mem_table();
-    ctx.register_table("mem_table", mem_table).unwrap();
-
-    let temp_dir = tempdir().unwrap();
-    let table_save_path = temp_dir.path().join("mem_table.tsv");
-
-    let d = ctx
-        .sql(&format!(
-            "COPY mem_table TO '{}' STORED AS TSV;",
-            table_save_path.display(),
-        ))
-        .await?;
-
-    let results = d.collect().await?;
-    println!(
-        "Number of inserted rows: {:?}",
-        (results[0]
-            .column_by_name("count")
-            .unwrap()
-            .as_primitive::<UInt64Type>()
-            .value(0))
-    );
-
-    Ok(())
-}
-
 // create a simple mem table
 fn create_mem_table() -> Arc<MemTable> {
     let fields = vec![
diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/custom_data_source/default_column_values.rs
similarity index 61%
rename from datafusion-examples/examples/default_column_values.rs
rename to datafusion-examples/examples/custom_data_source/default_column_values.rs
index d3a7d2ec67f3c..633b98244367e 100644
--- a/datafusion-examples/examples/default_column_values.rs
+++ b/datafusion-examples/examples/custom_data_source/default_column_values.rs
@@ -15,18 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
+//! See `main.rs` for how to run it.
+ use std::collections::HashMap; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::DFSchema; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; @@ -37,40 +37,37 @@ use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; -use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::{lit, SessionConfig}; +use datafusion::prelude::{SessionConfig, lit}; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, + replace_columns_with_literals, }; use futures::StreamExt; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; // Metadata key for storing default values in field metadata const DEFAULT_VALUE_METADATA_KEY: &str = "example.default_value"; -// Example showing how to implement custom default value handling for missing columns -// using field metadata and PhysicalExprAdapter. -// -// This example demonstrates how to: -// 1. Store default values in field metadata using a constant key -// 2. Create a custom PhysicalExprAdapter that reads these defaults -// 3. Inject default values for missing columns in filter predicates -// 4. Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation -// 5. Wrap string default values in cast expressions for proper type conversion -// -// Important: PhysicalExprAdapter is specifically designed for rewriting filter predicates -// that get pushed down to file scans. For handling missing columns in projections, -// other mechanisms in DataFusion are used (like SchemaAdapter). -// -// The metadata-based approach provides a flexible way to store default values as strings -// and cast them to the appropriate types at query time. - -#[tokio::main] -async fn main() -> Result<()> { +/// Example showing how to implement custom default value handling for missing columns +/// using field metadata and PhysicalExprAdapter. +/// +/// This example demonstrates how to: +/// 1. Store default values in field metadata using a constant key +/// 2. Create a custom PhysicalExprAdapter that reads these defaults +/// 3. Inject default values for missing columns in filter predicates using `replace_columns_with_literals` +/// 4. Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation +/// 5. Convert string default values to proper types using `ScalarValue::cast_to()` at planning time +/// +/// Important: PhysicalExprAdapter handles rewriting both filter predicates and projection +/// expressions for file scans, including handling missing columns. 
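+///
+/// For instance, a default can be attached to a logical field like this (a
+/// sketch; the column name and value are illustrative):
+///
+/// ```rust,ignore
+/// use std::collections::HashMap;
+/// use arrow::datatypes::{DataType, Field};
+///
+/// let field = Field::new("status", DataType::Utf8, true).with_metadata(
+///     HashMap::from([(DEFAULT_VALUE_METADATA_KEY.to_string(), "active".to_string())]),
+/// );
+/// ```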
+///
+/// The metadata-based approach provides a flexible way to store default values as strings
+/// and cast them to the appropriate types at planning time, avoiding runtime overhead.
+pub async fn default_column_values() -> Result<()> {
     println!("=== Creating example data with missing columns and default values ===");
 
     // Create sample data where the logical schema has more columns than the physical schema
@@ -81,15 +78,14 @@ async fn main() -> Result<()> {
 
         let mut buf = vec![];
         let props = WriterProperties::builder()
-            .set_max_row_group_size(2)
+            .set_max_row_group_row_count(Some(2))
            .build();
 
         let mut writer =
-            ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props))
-                .expect("creating writer");
+            ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props))?;
 
-        writer.write(&batch).expect("Writing batch");
-        writer.close().unwrap();
+        writer.write(&batch)?;
+        writer.close()?;
         buf
     };
     let path = Path::from("example.parquet");
@@ -138,12 +134,14 @@
     println!("\n=== Key Insight ===");
     println!("This example demonstrates how PhysicalExprAdapter works:");
     println!("1. Physical schema only has 'id' and 'name' columns");
-    println!("2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults");
-    println!("3. Our custom adapter intercepts filter expressions on missing columns");
-    println!("4. Default values from metadata are injected as cast expressions");
+    println!(
+        "2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults"
+    );
+    println!(
+        "3. Our custom adapter uses replace_columns_with_literals to inject default values"
+    );
+    println!("4. Default values from metadata are cast to proper types at planning time");
     println!("5. The DefaultPhysicalExprAdapter handles other schema adaptations");
-    println!("\nNote: PhysicalExprAdapter is specifically for filter predicates.");
-    println!("For projection columns, different mechanisms handle missing columns.");
 
     Ok(())
 }
@@ -202,12 +200,8 @@ impl DefaultValueTableProvider {
 
 #[async_trait]
 impl TableProvider for DefaultValueTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
-        self.schema.clone()
+        Arc::clone(&self.schema)
     }
 
     fn table_type(&self) -> TableType {
@@ -228,14 +222,14 @@ impl TableProvider for DefaultValueTableProvider {
         filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let schema = self.schema.clone();
+        let schema = Arc::clone(&self.schema);
         let df_schema = DFSchema::try_from(schema.clone())?;
         let filter = state.create_physical_expr(
             conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)),
             &df_schema,
         )?;
 
-        let parquet_source = ParquetSource::default()
+        let parquet_source = ParquetSource::new(schema.clone())
             .with_predicate(filter)
             .with_pushdown_filters(true);
 
@@ -257,10 +251,9 @@ impl TableProvider for DefaultValueTableProvider {
         let file_scan_config = FileScanConfigBuilder::new(
             ObjectStoreUrl::parse("memory://")?,
-            self.schema.clone(),
             Arc::new(parquet_source),
         )
-        .with_projection_indices(projection.cloned())
+        .with_projection_indices(projection.cloned())?
         .with_limit(limit)
         .with_file_group(file_group)
         .with_expr_adapter(Some(Arc::new(DefaultValuePhysicalExprAdapterFactory) as _));
@@ -280,17 +273,18 @@ impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory {
         &self,
         logical_file_schema: SchemaRef,
         physical_file_schema: SchemaRef,
-    ) -> Arc<dyn PhysicalExprAdapter> {
+    ) -> Result<Arc<dyn PhysicalExprAdapter>> {
         let default_factory = DefaultPhysicalExprAdapterFactory;
-        let default_adapter = default_factory
-            .create(logical_file_schema.clone(), physical_file_schema.clone());
+        let default_adapter = default_factory.create(
+            Arc::clone(&logical_file_schema),
+            Arc::clone(&physical_file_schema),
+        )?;
 
-        Arc::new(DefaultValuePhysicalExprAdapter {
+        Ok(Arc::new(DefaultValuePhysicalExprAdapter {
             logical_file_schema,
             physical_file_schema,
             default_adapter,
-            partition_values: Vec::new(),
-        })
+        }))
     }
 }
 
@@ -301,98 +295,36 @@ struct DefaultValuePhysicalExprAdapter {
     logical_file_schema: SchemaRef,
     physical_file_schema: SchemaRef,
     default_adapter: Arc<dyn PhysicalExprAdapter>,
-    partition_values: Vec<(FieldRef, ScalarValue)>,
 }
 
 impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter {
     fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
-        // First try our custom default value injection for missing columns
-        let rewritten = expr
-            .transform(|expr| {
-                self.inject_default_values(
-                    expr,
-                    &self.logical_file_schema,
-                    &self.physical_file_schema,
-                )
-            })
-            .data()?;
-
-        // Then apply the default adapter as a fallback to handle standard schema differences
-        // like type casting, partition column handling, etc.
-        let default_adapter = if !self.partition_values.is_empty() {
-            self.default_adapter
-                .with_partition_values(self.partition_values.clone())
-        } else {
-            self.default_adapter.clone()
-        };
-
-        default_adapter.rewrite(rewritten)
-    }
-
-    fn with_partition_values(
-        &self,
-        partition_values: Vec<(FieldRef, ScalarValue)>,
-    ) -> Arc<dyn PhysicalExprAdapter> {
-        Arc::new(DefaultValuePhysicalExprAdapter {
-            logical_file_schema: self.logical_file_schema.clone(),
-            physical_file_schema: self.physical_file_schema.clone(),
-            default_adapter: self.default_adapter.clone(),
-            partition_values,
-        })
-    }
-}
-
-impl DefaultValuePhysicalExprAdapter {
-    fn inject_default_values(
-        &self,
-        expr: Arc<dyn PhysicalExpr>,
-        logical_file_schema: &Schema,
-        physical_file_schema: &Schema,
-    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
-        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
-            let column_name = column.name();
-
-            // Check if this column exists in the physical schema
-            if physical_file_schema.index_of(column_name).is_err() {
-                // Column is missing from physical schema, check if logical schema has a default
-                if let Ok(logical_field) =
-                    logical_file_schema.field_with_name(column_name)
-                {
-                    if let Some(default_value_str) =
-                        logical_field.metadata().get(DEFAULT_VALUE_METADATA_KEY)
-                    {
-                        // Create a string literal and wrap it in a cast expression
-                        let default_literal = self.create_default_value_expr(
-                            default_value_str,
-                            logical_field.data_type(),
-                        )?;
-                        return Ok(Transformed::yes(default_literal));
-                    }
-                }
+        // Pre-compute replacements for missing columns with default values
+        let mut replacements = HashMap::new();
+        for field in self.logical_file_schema.fields() {
+            // Skip columns that exist in physical schema
+            if self.physical_file_schema.index_of(field.name()).is_ok() {
+                continue;
             }
-        }
-
-        // No transformation needed
-        Ok(Transformed::no(expr))
-    }
-
-    fn create_default_value_expr(
-        &self,
-        value_str: &str,
-        data_type: &DataType,
-    ) -> Result<Arc<dyn PhysicalExpr>> {
-        // Create a string literal with the default value
-        let string_literal =
-            Arc::new(Literal::new(ScalarValue::Utf8(Some(value_str.to_string()))));
-
-        // If the target type is already Utf8, return the string literal directly
-        if matches!(data_type, DataType::Utf8) {
-            return Ok(string_literal);
+            // Check if this missing column has a default value in metadata
+            if let Some(default_str) = field.metadata().get(DEFAULT_VALUE_METADATA_KEY) {
+                // Create a Utf8 ScalarValue from the string and cast it to the target type
+                let string_value = ScalarValue::Utf8(Some(default_str.to_string()));
+                let typed_value = string_value.cast_to(field.data_type())?;
+                replacements.insert(field.name().as_str(), typed_value);
+            }
         }
 
-        // Otherwise, wrap the string literal in a cast expression
-        let cast_expr = Arc::new(CastExpr::new(string_literal, data_type.clone(), None));
+        // Replace columns with their default literals if any
+        let rewritten = if !replacements.is_empty() {
+            let refs: HashMap<_, _> = replacements.iter().map(|(k, v)| (*k, v)).collect();
+            replace_columns_with_literals(expr, &refs)?
+        } else {
+            expr
+        };
 
-        Ok(cast_expr)
+        // Apply the default adapter as a fallback for other schema adaptations
+        self.default_adapter.rewrite(rewritten)
     }
 }
diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs
similarity index 90%
rename from datafusion-examples/examples/file_stream_provider.rs
rename to datafusion-examples/examples/custom_data_source/file_stream_provider.rs
index e6c59d57e98de..5b43072d43f80 100644
--- a/datafusion-examples/examples/file_stream_provider.rs
+++ b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs
@@ -15,6 +15,31 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+
+/// Demonstrates how to use [`FileStreamProvider`] and [`StreamTable`] to stream data
+/// from a file-like source (FIFO) into DataFusion for continuous querying.
+///
+/// On non-Windows systems, this example creates a named pipe (FIFO) and
+/// writes rows into it asynchronously while DataFusion reads the data
+/// through a `FileStreamProvider`.
+///
+/// This illustrates how to integrate dynamically updated data sources
+/// with DataFusion without needing to reload the entire dataset each time.
+///
+/// This example does not work on Windows.
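+///
+/// The writer side boils down to appending lines to the FIFO path while a
+/// query is running (a simplified sketch of what the module below does;
+/// `fifo_path` and the row values are illustrative, error handling omitted):
+///
+/// ```rust,ignore
+/// use std::io::Write;
+///
+/// let mut fifo = std::fs::OpenOptions::new().write(true).open(&fifo_path)?;
+/// // Each line becomes a CSV row that the running query can pick up.
+/// writeln!(fifo, "{},{}", a1, a2)?;
+/// ```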
+pub async fn file_stream_provider() -> datafusion::error::Result<()> { + #[cfg(target_os = "windows")] + { + println!("file_stream_provider example does not work on windows"); + Ok(()) + } + #[cfg(not(target_os = "windows"))] + { + non_windows::main().await + } +} + #[cfg(not(target_os = "windows"))] mod non_windows { use datafusion::assert_batches_eq; @@ -22,8 +47,8 @@ mod non_windows { use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; @@ -34,9 +59,9 @@ mod non_windows { use tempfile::TempDir; use tokio::task::JoinSet; - use datafusion::common::{exec_err, Result}; - use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; + use datafusion::common::{Result, exec_err}; use datafusion::datasource::TableProvider; + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::logical_expr::SortExpr; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -101,7 +126,6 @@ mod non_windows { let broken_pipe_timeout = Duration::from_secs(10); let sa = file_path; // Spawn a new thread to write to the FIFO file - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests tasks.spawn_blocking(move || { let file = OpenOptions::new().write(true).open(sa).unwrap(); // Reference time to use when deciding to fail the test @@ -186,16 +210,3 @@ mod non_windows { Ok(()) } } - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - #[cfg(target_os = "windows")] - { - println!("file_stream_provider example does not work on windows"); - Ok(()) - } - #[cfg(not(target_os = "windows"))] - { - non_windows::main().await - } -} diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs new file mode 100644 index 0000000000000..40409d3690d3a --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/main.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These examples are all related to extending or defining how DataFusion reads data +//! +//! These examples demonstrate how DataFusion reads data. +//! +//! ## Usage +//! ```bash +//! cargo run --example custom_data_source -- [all|adapter_serialization|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `adapter_serialization` +//! 
(file: adapter_serialization.rs, desc: Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception)
+//!
+//! - `csv_json_opener`
+//!   (file: csv_json_opener.rs, desc: Use low-level FileOpener APIs for CSV/JSON)
+//!
+//! - `csv_sql_streaming`
+//!   (file: csv_sql_streaming.rs, desc: Run a streaming SQL query against CSV data)
+//!
+//! - `custom_datasource`
+//!   (file: custom_datasource.rs, desc: Query a custom TableProvider)
+//!
+//! - `custom_file_casts`
+//!   (file: custom_file_casts.rs, desc: Implement custom casting rules)
+//!
+//! - `custom_file_format`
+//!   (file: custom_file_format.rs, desc: Write to a custom file format)
+//!
+//! - `default_column_values`
+//!   (file: default_column_values.rs, desc: Custom default values using metadata)
+//!
+//! - `file_stream_provider`
+//!   (file: file_stream_provider.rs, desc: Read/write via FileStreamProvider for streams)
+
+mod adapter_serialization;
+mod csv_json_opener;
+mod csv_sql_streaming;
+mod custom_datasource;
+mod custom_file_casts;
+mod custom_file_format;
+mod default_column_values;
+mod file_stream_provider;
+
+use datafusion::error::{DataFusionError, Result};
+use strum::{IntoEnumIterator, VariantNames};
+use strum_macros::{Display, EnumIter, EnumString, VariantNames};
+
+#[derive(EnumIter, EnumString, Display, VariantNames)]
+#[strum(serialize_all = "snake_case")]
+enum ExampleKind {
+    All,
+    AdapterSerialization,
+    CsvJsonOpener,
+    CsvSqlStreaming,
+    CustomDatasource,
+    CustomFileCasts,
+    CustomFileFormat,
+    DefaultColumnValues,
+    FileStreamProvider,
+}
+
+impl ExampleKind {
+    const EXAMPLE_NAME: &str = "custom_data_source";
+
+    fn runnable() -> impl Iterator<Item = ExampleKind> {
+        ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All))
+    }
+
+    async fn run(&self) -> Result<()> {
+        match self {
+            ExampleKind::All => {
+                for example in ExampleKind::runnable() {
+                    println!("Running example: {example}");
+                    Box::pin(example.run()).await?;
+                }
+            }
+            ExampleKind::AdapterSerialization => {
+                adapter_serialization::adapter_serialization().await?
+            }
+            ExampleKind::CsvJsonOpener => csv_json_opener::csv_json_opener().await?,
+            ExampleKind::CsvSqlStreaming => {
+                csv_sql_streaming::csv_sql_streaming().await?
+            }
+            ExampleKind::CustomDatasource => {
+                custom_datasource::custom_datasource().await?
+            }
+            ExampleKind::CustomFileCasts => {
+                custom_file_casts::custom_file_casts().await?
+            }
+            ExampleKind::CustomFileFormat => {
+                custom_file_format::custom_file_format().await?
+            }
+            ExampleKind::DefaultColumnValues => {
+                default_column_values::default_column_values().await?
+            }
+            ExampleKind::FileStreamProvider => {
+                file_stream_provider::file_stream_provider().await?
+            }
+        }
+        Ok(())
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let usage = format!(
+        "Usage: cargo run --example {} -- [{}]",
+        ExampleKind::EXAMPLE_NAME,
+        ExampleKind::VARIANTS.join("|")
+    );
+
+    let example: ExampleKind = std::env::args()
+        .nth(1)
+        .unwrap_or_else(|| ExampleKind::All.to_string())
+        .parse()
+        .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?;
+
+    example.run().await
+}
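Both grouped-example binaries in this diff (`custom_data_source` and `data_io` below) share the same strum-based dispatcher. For readers unfamiliar with those derives, here is a minimal, self-contained sketch of the pattern; the two-variant `Demo` enum is hypothetical, and the only assumed dependencies are `strum` and `strum_macros`:

```rust
use strum::{IntoEnumIterator, VariantNames};
use strum_macros::{Display, EnumIter, EnumString, VariantNames};

#[derive(Display, EnumIter, EnumString, VariantNames)]
#[strum(serialize_all = "snake_case")]
enum Demo {
    All,
    CsvJsonOpener,
}

fn main() {
    // EnumString + serialize_all lets the CLI argument parse from snake_case
    let parsed: Demo = "csv_json_opener".parse().unwrap();
    println!("parsed: {parsed}");
    // VARIANTS supplies the snake_case names for the usage string
    println!("usage: [{}]", Demo::VARIANTS.join("|"));
    // EnumIter drives the `all` subcommand's loop over every variant
    for variant in Demo::iter() {
        println!("would run: {variant}");
    }
}
```

The derive macro and the trait can share the name `VariantNames` in one scope because macros and traits live in separate namespaces, which is why the example files import both.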
diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/data_io/catalog.rs
similarity index 95%
rename from datafusion-examples/examples/catalog.rs
rename to datafusion-examples/examples/data_io/catalog.rs
index 229867cdfc5bb..7e5cc5a4cfc05 100644
--- a/datafusion-examples/examples/catalog.rs
+++ b/datafusion-examples/examples/data_io/catalog.rs
@@ -15,27 +15,29 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+//!
 //! Simple example of a catalog/schema implementation.
 use async_trait::async_trait;
 use datafusion::{
     arrow::util::pretty,
     catalog::{CatalogProvider, CatalogProviderList, SchemaProvider},
     datasource::{
-        file_format::{csv::CsvFormat, FileFormat},
-        listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
         TableProvider,
+        file_format::{FileFormat, csv::CsvFormat},
+        listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
     },
     error::Result,
     execution::context::SessionState,
     prelude::SessionContext,
 };
 use std::sync::RwLock;
-use std::{any::Any, collections::HashMap, path::Path, sync::Arc};
+use std::{collections::HashMap, path::Path, sync::Arc};
 use std::{fs::File, io::Write};
 use tempfile::TempDir;
 
-#[tokio::main]
-async fn main() -> Result<()> {
+/// Register the table into a custom catalog
+pub async fn catalog() -> Result<()> {
     env_logger::builder()
         .filter_level(log::LevelFilter::Info)
         .init();
@@ -134,12 +136,13 @@ struct DirSchemaOpts<'a> {
     dir: &'a Path,
     format: Arc<dyn FileFormat>,
 }
+
 /// Schema where every file with extension `ext` in a given `dir` is a table.
 #[derive(Debug)]
 struct DirSchema {
-    ext: String,
     tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
 }
+
 impl DirSchema {
     async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result<Arc<Self>> {
         let DirSchemaOpts { ext, dir, format } = opts;
@@ -169,21 +172,12 @@ impl DirSchema {
         }
         Ok(Arc::new(Self {
             tables: RwLock::new(tables),
-            ext: ext.to_string(),
         }))
     }
-
-    #[allow(unused)]
-    fn name(&self) -> &str {
-        &self.ext
-    }
 }
 
 #[async_trait]
 impl SchemaProvider for DirSchema {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn table_names(&self) -> Vec<String> {
         let tables = self.tables.read().unwrap();
         tables.keys().cloned().collect::<Vec<_>>()
@@ -198,6 +192,7 @@ impl SchemaProvider for DirSchema {
         let tables = self.tables.read().unwrap();
         tables.contains_key(name)
     }
+
     fn register_table(
         &self,
         name: String,
@@ -211,7 +206,6 @@ impl SchemaProvider for DirSchema {
 
     /// If supported by the implementation, removes an existing table from this schema and returns it.
     /// If no table of that name exists, returns Ok(None).
-    #[allow(unused_variables)]
     fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
         let mut tables = self.tables.write().unwrap();
         log::info!("dropping table {name}");
@@ -223,6 +217,7 @@ struct DirCatalog {
     schemas: RwLock<HashMap<String, Arc<dyn SchemaProvider>>>,
 }
+
 impl DirCatalog {
     fn new() -> Self {
         Self {
@@ -230,10 +225,8 @@
         }
     }
 }
+
 impl CatalogProvider for DirCatalog {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
     fn register_schema(
         &self,
         name: &str,
@@ -260,11 +253,13 @@
             }
         }
     }
 }
+
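As a complement to the custom `DirCatalog` in this file, here is a hedged sketch of the registration flow using DataFusion's built-in in-memory providers. The `datafusion::catalog::memory` paths match the module used elsewhere in this diff, but treat the exact types as an assumption worth verifying against your DataFusion version:

```rust
use std::sync::Arc;

use datafusion::catalog::memory::{MemoryCatalogProvider, MemorySchemaProvider};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

fn register_catalog_sketch() -> Result<()> {
    let ctx = SessionContext::new();
    let catalog = Arc::new(MemoryCatalogProvider::new());
    // A schema groups tables; a catalog groups schemas.
    catalog.register_schema("my_schema", Arc::new(MemorySchemaProvider::new()))?;
    // Tables registered under this catalog resolve as my_catalog.my_schema.<table>
    ctx.register_catalog("my_catalog", catalog);
    Ok(())
}
```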
 /// A catalog list holds multiple catalog providers. Each context has a single catalog list.
 #[derive(Debug)]
 struct CustomCatalogProviderList {
     catalogs: RwLock<HashMap<String, Arc<dyn CatalogProvider>>>,
 }
+
 impl CustomCatalogProviderList {
     fn new() -> Self {
         Self {
@@ -272,10 +267,8 @@
         }
     }
 }
+
 impl CatalogProviderList for CustomCatalogProviderList {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
     fn register_catalog(
         &self,
         name: String,
diff --git a/datafusion-examples/examples/data_io/in_memory_object_store.rs b/datafusion-examples/examples/data_io/in_memory_object_store.rs
new file mode 100644
index 0000000000000..9a308f06c5abd
--- /dev/null
+++ b/datafusion-examples/examples/data_io/in_memory_object_store.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! See `main.rs` for how to run it.
+//!
+//! This follows the recommended approach: implement the `ObjectStore` trait
+//! (or use an existing implementation), register it with DataFusion, and then
+//! read a URL "path" from that store.
+//! See the in-memory reference implementation:
+//! https://docs.rs/object_store/latest/object_store/memory/struct.InMemory.html
+
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::assert_batches_eq;
+use datafusion::common::Result;
+use datafusion::execution::object_store::ObjectStoreUrl;
+use datafusion::prelude::{CsvReadOptions, SessionContext};
+use object_store::memory::InMemory;
+use object_store::path::Path;
+use object_store::{ObjectStore, ObjectStoreExt, PutPayload};
+
+/// Demonstrates reading CSV data from an in-memory object store.
+///
+/// The same pattern applies to JSON/Parquet: register a store for a URL
+/// prefix, write bytes into the store, then read via that URL.
+pub async fn in_memory_object_store() -> Result<()> {
+    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
+    let ctx = SessionContext::new();
+    let object_store_url = ObjectStoreUrl::parse("memory://")?;
+    // Register a URL prefix to route reads through this object store.
+    ctx.register_object_store(object_store_url.as_ref(), Arc::clone(&store));
+
+    let schema = Schema::new(vec![
+        Field::new("id", DataType::Int64, false),
+        Field::new("name", DataType::Utf8, false),
+    ]);
+
+    println!("=== CSV from memory ===");
+    let csv_path = Path::from("/people.csv");
+    let csv_data = b"id,name\n1,Alice\n2,Bob\n";
+    // Write bytes into the in-memory object store.
+    store
+        .put(&csv_path, PutPayload::from_static(csv_data))
+        .await?;
+    // Read using the URL that matches the registered prefix.
+    let csv = ctx
+        .read_csv(
+            "memory:///people.csv",
+            CsvReadOptions::new().schema(&schema),
+        )
+        .await?
+ .collect() + .await?; + #[rustfmt::skip] + let expected = [ + "+----+-------+", + "| id | name |", + "+----+-------+", + "| 1 | Alice |", + "| 2 | Bob |", + "+----+-------+", + ]; + assert_batches_eq!(expected, &csv); + + Ok(()) +} diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/data_io/json_shredding.rs similarity index 74% rename from datafusion-examples/examples/json_shredding.rs rename to datafusion-examples/examples/data_io/json_shredding.rs index 5ef8b59b64200..72fbb56773123 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/data_io/json_shredding.rs @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +//! See `main.rs` for how to run it. + use std::sync::Arc; use arrow::array::{RecordBatch, StringArray}; -use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion::common::{assert_contains, exec_datafusion_err, Result}; +use datafusion::common::{Result, assert_contains, exec_datafusion_err}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -37,7 +38,7 @@ use datafusion::logical_expr::{ use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::PhysicalExpr; -use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; +use datafusion::physical_expr::{ScalarFunctionExpr, expressions}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -45,7 +46,7 @@ use datafusion_physical_expr_adapter::{ }; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use object_store::{ObjectStoreExt, PutPayload}; // Example showing how to implement custom filter rewriting for JSON shredding. // @@ -63,8 +64,7 @@ use object_store::{ObjectStore, PutPayload}; // 1. Push down predicates for better filtering // 2. Avoid expensive JSON parsing at query time // 3. 
Leverage columnar storage benefits for the materialized fields
-#[tokio::main]
-async fn main() -> Result<()> {
+pub async fn json_shredding() -> Result<()> {
     println!("=== Creating example data with flat columns and underscore prefixes ===");
 
     // Create sample data with flat columns using underscore prefixes
@@ -75,7 +75,7 @@
     let mut buf = vec![];
 
     let props = WriterProperties::builder()
-        .set_max_row_group_size(2)
+        .set_max_row_group_row_count(Some(2))
         .build();
 
     let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props))
@@ -206,10 +206,6 @@ impl Default for JsonGetStr {
 }
 
 impl ScalarUDFImpl for JsonGetStr {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn name(&self) -> &str {
         "json_get_str"
     }
@@ -232,7 +228,7 @@ impl ScalarUDFImpl for JsonGetStr {
             _ => {
                 return Err(exec_datafusion_err!(
                     "json_get_str first argument must be a string"
-                ))
+                ));
             }
         };
         // We expect a string array that contains JSON strings
@@ -248,7 +244,7 @@ impl ScalarUDFImpl for JsonGetStr {
             _ => {
                 return Err(exec_datafusion_err!(
                     "json_get_str second argument must be a string array"
-                ))
+                ));
             }
         };
         let values = json_array
@@ -274,17 +270,17 @@ impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory {
         &self,
         logical_file_schema: SchemaRef,
         physical_file_schema: SchemaRef,
-    ) -> Arc<dyn PhysicalExprAdapter> {
+    ) -> Result<Arc<dyn PhysicalExprAdapter>> {
         let default_factory = DefaultPhysicalExprAdapterFactory;
-        let default_adapter = default_factory
-            .create(logical_file_schema.clone(), physical_file_schema.clone());
+        let default_adapter = default_factory.create(
+            Arc::clone(&logical_file_schema),
+            Arc::clone(&physical_file_schema),
+        )?;
 
-        Arc::new(ShreddedJsonRewriter {
-            logical_file_schema,
+        Ok(Arc::new(ShreddedJsonRewriter {
             physical_file_schema,
             default_adapter,
-            partition_values: Vec::new(),
-        })
+        }))
     }
 }
 
@@ -292,10 +288,8 @@
 /// and wraps DefaultPhysicalExprAdapter for standard schema adaptation
 #[derive(Debug)]
 struct ShreddedJsonRewriter {
-    logical_file_schema: SchemaRef,
     physical_file_schema: SchemaRef,
     default_adapter: Arc<dyn PhysicalExprAdapter>,
-    partition_values: Vec<(FieldRef, ScalarValue)>,
 }
 
 impl PhysicalExprAdapter for ShreddedJsonRewriter {
@@ -306,27 +300,8 @@
             .data()?;
 
         // Then apply the default adapter as a fallback to handle standard schema differences
-        // like type casting, missing columns, and partition column handling
-        let default_adapter = if !self.partition_values.is_empty() {
-            self.default_adapter
-                .with_partition_values(self.partition_values.clone())
-        } else {
-            self.default_adapter.clone()
-        };
-
-        default_adapter.rewrite(rewritten)
-    }
-
-    fn with_partition_values(
-        &self,
-        partition_values: Vec<(FieldRef, ScalarValue)>,
-    ) -> Arc<dyn PhysicalExprAdapter> {
-        Arc::new(ShreddedJsonRewriter {
-            logical_file_schema: self.logical_file_schema.clone(),
-            physical_file_schema: self.physical_file_schema.clone(),
-            default_adapter: self.default_adapter.clone(),
-            partition_values,
-        })
+        // like type casting and missing columns
+        self.default_adapter.rewrite(rewritten)
    }
 }
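The next hunk is mostly a mechanical rewrite of nested `if let` blocks into let-chains. A tiny standalone sketch of the construct (stable as of the Rust 2024 edition), for readers who have not met it yet:

```rust
// Several patterns and boolean guards combined in one `if`,
// avoiding the pyramid of nested blocks that the hunk below removes.
fn first_even_square(values: &[Option<i32>]) -> Option<i32> {
    for value in values {
        if let Some(x) = value
            && x % 2 == 0
        {
            return Some(x * x);
        }
    }
    None
}
```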
@@ -336,44 +311,39 @@ impl ShreddedJsonRewriter {
         expr: Arc<dyn PhysicalExpr>,
         physical_file_schema: &Schema,
     ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
-        if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
-            if func.name() == "json_get_str" && func.args().len() == 2 {
-                // Get the key from the first argument
-                if let Some(literal) = func.args()[0]
-                    .as_any()
-                    .downcast_ref::<expressions::Literal>()
+        if let Some(func) = expr.downcast_ref::<ScalarFunctionExpr>()
+            && func.name() == "json_get_str"
+            && func.args().len() == 2
+        {
+            // Get the key from the first argument
+            if let Some(literal) = func.args()[0].downcast_ref::<expressions::Literal>()
+                && let ScalarValue::Utf8(Some(field_name)) = literal.value()
+            {
+                // Get the column from the second argument
+                if let Some(column) = func.args()[1].downcast_ref::<expressions::Column>()
                 {
-                    if let ScalarValue::Utf8(Some(field_name)) = literal.value() {
-                        // Get the column from the second argument
-                        if let Some(column) = func.args()[1]
-                            .as_any()
-                            .downcast_ref::<expressions::Column>()
-                        {
-                            let column_name = column.name();
-                            // Check if there's a flat column with underscore prefix
-                            let flat_column_name = format!("_{column_name}.{field_name}");
-
-                            if let Ok(flat_field_index) =
-                                physical_file_schema.index_of(&flat_column_name)
-                            {
-                                let flat_field =
-                                    physical_file_schema.field(flat_field_index);
-
-                                if flat_field.data_type() == &DataType::Utf8 {
-                                    // Replace the whole expression with a direct column reference
-                                    let new_expr = Arc::new(expressions::Column::new(
-                                        &flat_column_name,
-                                        flat_field_index,
-                                    ))
-                                        as Arc<dyn PhysicalExpr>;
-
-                                    return Ok(Transformed {
-                                        data: new_expr,
-                                        tnr: TreeNodeRecursion::Stop,
-                                        transformed: true,
-                                    });
-                                }
-                            }
+                    let column_name = column.name();
+                    // Check if there's a flat column with underscore prefix
+                    let flat_column_name = format!("_{column_name}.{field_name}");
+
+                    if let Ok(flat_field_index) =
+                        physical_file_schema.index_of(&flat_column_name)
+                    {
+                        let flat_field = physical_file_schema.field(flat_field_index);
+
+                        if flat_field.data_type() == &DataType::Utf8 {
+                            // Replace the whole expression with a direct column reference
+                            let new_expr = Arc::new(expressions::Column::new(
+                                &flat_column_name,
+                                flat_field_index,
+                            ))
+                                as Arc<dyn PhysicalExpr>;
+
+                            return Ok(Transformed {
+                                data: new_expr,
+                                tnr: TreeNodeRecursion::Stop,
+                                transformed: true,
+                            });
                         }
                     }
                 }
             }
         }
diff --git a/datafusion-examples/examples/data_io/main.rs b/datafusion-examples/examples/data_io/main.rs
new file mode 100644
index 0000000000000..4656a83670aaf
--- /dev/null
+++ b/datafusion-examples/examples/data_io/main.rs
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Examples of data formats and I/O
+//!
+//! These examples demonstrate working with different data formats and I/O sources.
+//!
+//! ## Usage
+//! ```bash
+//! cargo run --example data_io -- [all|catalog|in_memory_object_store|json_shredding|parquet_adv_idx|parquet_emb_idx|parquet_enc_with_kms|parquet_enc|parquet_exec_visitor|parquet_idx|query_http_csv|remote_catalog]
+//! ```
+//!
+//! Each subcommand runs a corresponding example:
+//! - `all` — run all examples included in this module
+//!
+//! - `catalog`
+//!   (file: catalog.rs, desc: Register tables into a custom catalog)
+//!
+//! - `in_memory_object_store`
+//!   (file: in_memory_object_store.rs, desc: Read CSV from an in-memory object store (pattern applies to JSON/Parquet))
+//!
+//! - `json_shredding`
+//!   (file: json_shredding.rs, desc: Implement filter rewriting for JSON shredding)
+//!
+//! - `parquet_adv_idx`
+//!   (file: parquet_advanced_index.rs, desc: Create a secondary index across multiple parquet files)
+//!
+//! - `parquet_emb_idx`
+//!   (file: parquet_embedded_index.rs, desc: Store a custom index inside Parquet files)
+//!
+//! - `parquet_enc`
+//!   (file: parquet_encrypted.rs, desc: Read & write encrypted Parquet files)
+//!
+//! - `parquet_enc_with_kms`
+//!   (file: parquet_encrypted_with_kms.rs, desc: Encrypted Parquet I/O using a KMS-backed factory)
+//!
+//! - `parquet_exec_visitor`
+//!   (file: parquet_exec_visitor.rs, desc: Extract statistics by visiting an ExecutionPlan)
+//!
+//! - `parquet_idx`
+//!   (file: parquet_index.rs, desc: Create a secondary index)
+//!
+//! - `query_http_csv`
+//!   (file: query_http_csv.rs, desc: Query CSV files via HTTP)
+//!
+//! - `remote_catalog`
+//!   (file: remote_catalog.rs, desc: Interact with a remote catalog)
+
+mod catalog;
+mod in_memory_object_store;
+mod json_shredding;
+mod parquet_advanced_index;
+mod parquet_embedded_index;
+mod parquet_encrypted;
+mod parquet_encrypted_with_kms;
+mod parquet_exec_visitor;
+mod parquet_index;
+mod query_http_csv;
+mod remote_catalog;
+
+use datafusion::error::{DataFusionError, Result};
+use strum::{IntoEnumIterator, VariantNames};
+use strum_macros::{Display, EnumIter, EnumString, VariantNames};
+
+#[derive(EnumIter, EnumString, Display, VariantNames)]
+#[strum(serialize_all = "snake_case")]
+enum ExampleKind {
+    All,
+    Catalog,
+    InMemoryObjectStore,
+    JsonShredding,
+    ParquetAdvIdx,
+    ParquetEmbIdx,
+    ParquetEnc,
+    ParquetEncWithKms,
+    ParquetExecVisitor,
+    ParquetIdx,
+    QueryHttpCsv,
+    RemoteCatalog,
+}
+
+impl ExampleKind {
+    const EXAMPLE_NAME: &str = "data_io";
+
+    fn runnable() -> impl Iterator<Item = ExampleKind> {
+        ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All))
+    }
+
+    async fn run(&self) -> Result<()> {
+        match self {
+            ExampleKind::All => {
+                for example in ExampleKind::runnable() {
+                    println!("Running example: {example}");
+                    Box::pin(example.run()).await?;
+                }
+            }
+            ExampleKind::Catalog => catalog::catalog().await?,
+            ExampleKind::InMemoryObjectStore => {
+                in_memory_object_store::in_memory_object_store().await?
+            }
+            ExampleKind::JsonShredding => json_shredding::json_shredding().await?,
+            ExampleKind::ParquetAdvIdx => {
+                parquet_advanced_index::parquet_advanced_index().await?
+            }
+            ExampleKind::ParquetEmbIdx => {
+                parquet_embedded_index::parquet_embedded_index().await?
+            }
+            ExampleKind::ParquetEncWithKms => {
+                parquet_encrypted_with_kms::parquet_encrypted_with_kms().await?
+            }
+            ExampleKind::ParquetEnc => parquet_encrypted::parquet_encrypted().await?,
+            ExampleKind::ParquetExecVisitor => {
+                parquet_exec_visitor::parquet_exec_visitor().await?
+            }
+            ExampleKind::ParquetIdx => parquet_index::parquet_index().await?,
+            ExampleKind::QueryHttpCsv => query_http_csv::query_http_csv().await?,
+            ExampleKind::RemoteCatalog => remote_catalog::remote_catalog().await?,
+        }
+        Ok(())
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let usage = format!(
+        "Usage: cargo run --example {} -- [{}]",
+        ExampleKind::EXAMPLE_NAME,
+        ExampleKind::VARIANTS.join("|")
+    );
+
+    let example: ExampleKind = std::env::args()
+        .nth(1)
+        .unwrap_or_else(|| ExampleKind::All.to_string())
+        .parse()
+        .map_err(|_| DataFusionError::Execution(format!("Unknown example.
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/data_io/parquet_advanced_index.rs similarity index 97% rename from datafusion-examples/examples/advanced_parquet_index.rs rename to datafusion-examples/examples/data_io/parquet_advanced_index.rs index 1c560be6d08a6..9e69c7f15a841 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/data_io/parquet_advanced_index.rs @@ -15,40 +15,41 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +//! See `main.rs` for how to run it. + use std::collections::{HashMap, HashSet}; use std::fs::File; use std::ops::Range; use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use datafusion::catalog::Session; use datafusion::common::{ - internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err, }; +use datafusion::datasource::TableProvider; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; use datafusion::datasource::physical_plan::{ FileScanConfigBuilder, ParquetFileReaderFactory, ParquetSource, }; -use datafusion::datasource::TableProvider; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{TableProviderFilterPushDown, TableType}; +use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::arrow::arrow_reader::{ ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, }; use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use datafusion::parquet::arrow::ArrowWriter; -use datafusion::parquet::file::metadata::ParquetMetaData; +use datafusion::parquet::file::metadata::{PageIndexPolicy, ParquetMetaData}; use datafusion::parquet::file::properties::{EnabledStatistics, WriterProperties}; use datafusion::parquet::schema::types::ColumnPath; -use datafusion::physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion::physical_optimizer::pruning::PruningPredicate; -use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::prelude::*; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; @@ -56,8 +57,8 @@ use arrow::datatypes::SchemaRef; use async_trait::async_trait; use bytes::Bytes; use datafusion::datasource::memory::DataSourceExec; -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use object_store::ObjectStore; use tempfile::TempDir; use url::Url; @@ -121,7 +122,6 @@ use url::Url; /// │ ╚═══════════════════╝ │ 1. With cached ParquetMetadata, so /// └───────────────────────┘ the ParquetSource does not re-read / /// Parquet File decode the thrift footer -/// /// ``` /// /// Within a Row Group, Column Chunks store data in DataPages. 
This example also
@@ -156,8 +156,7 @@ use url::Url;
 ///
 /// [`ListingTable`]: datafusion::datasource::listing::ListingTable
 /// [Page Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md)
-#[tokio::main]
-async fn main() -> Result<()> {
+pub async fn parquet_advanced_index() -> Result<()> {
     // the object store is used to read the parquet files (in this case, it is
     // a local file system, but in a real system it could be S3, GCS, etc)
     let object_store: Arc<dyn ObjectStore> =
@@ -240,6 +239,7 @@ pub struct IndexTableProvider {
     /// if true, use row selections in addition to row group selections
     use_row_selections: AtomicBool,
 }
+
 impl IndexTableProvider {
     /// Create a new IndexTableProvider
     /// * `object_store` - the object store implementation to use for reading files
@@ -409,7 +409,7 @@ impl IndexedFile {
         let options = ArrowReaderOptions::new()
             // Load the page index when reading metadata to cache
             // so it is available to interpret row selections
-            .with_page_index(true);
+            .with_page_index_policy(PageIndexPolicy::Required);
         let reader =
             ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?;
         let metadata = reader.metadata().clone();
@@ -450,10 +450,6 @@ impl IndexedFile {
 /// so that we can query it as a table.
 #[async_trait]
 impl TableProvider for IndexTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.indexed_file.schema)
     }
@@ -492,19 +488,18 @@ impl TableProvider for IndexTableProvider {
             .with_file(indexed_file);
 
         let file_source = Arc::new(
-            ParquetSource::default()
+            ParquetSource::new(schema.clone())
                 // provide the predicate so the DataSourceExec can try and prune
                 // row groups internally
                 .with_predicate(predicate)
                 // provide the factory to create parquet reader without re-reading metadata
                 .with_parquet_file_reader_factory(Arc::new(reader_factory)),
         );
-        let file_scan_config =
-            FileScanConfigBuilder::new(object_store_url, schema, file_source)
-                .with_limit(limit)
-                .with_projection_indices(projection.cloned())
-                .with_file(partitioned_file)
-                .build();
+        let file_scan_config = FileScanConfigBuilder::new(object_store_url, file_source)
+            .with_limit(limit)
+            .with_projection_indices(projection.cloned())?
+            .with_file(partitioned_file)
+            .build();
 
         // Finally, put it all together into a DataSourceExec
         Ok(DataSourceExec::from_data_source(file_scan_config))
@@ -541,6 +536,7 @@ impl CachedParquetFileReaderFactory {
             metadata: HashMap::new(),
         }
     }
+
     /// Add the pre-parsed information about the file to the factory
     fn with_file(mut self, indexed_file: &IndexedFile) -> Self {
         self.metadata.insert(
@@ -566,7 +562,7 @@ impl ParquetFileReaderFactory for CachedParquetFileReaderFactory {
             .object_meta
             .location
             .parts()
-            .last()
+            .next_back()
             .expect("No path in location")
             .as_ref()
             .to_string();
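The `with_page_index(true)` call above is migrated to `with_page_index_policy`, replacing a boolean with an explicit policy. A hedged sketch of the new call, using exactly the import path this diff introduces; the full variant set (`Skip`, `Optional`, `Required`) is our reading of the parquet crate and worth double-checking:

```rust
use datafusion::parquet::arrow::arrow_reader::ArrowReaderOptions;
use datafusion::parquet::file::metadata::PageIndexPolicy;

fn reader_options() -> ArrowReaderOptions {
    // `Required` fails if the page index is missing, which suits this example
    // because the later row-selection logic depends on the index being present.
    ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required)
}
```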
@@ -658,7 +654,7 @@ fn make_demo_file(path: impl AsRef<Path>, value_range: Range<i32>) -> Result<()>
     // enable page statistics for the tag column,
     // for everything else.
     let props = WriterProperties::builder()
-        .set_max_row_group_size(100)
+        .set_max_row_group_row_count(Some(100))
         // compute column chunk (per row group) statistics by default
         .set_statistics_enabled(EnabledStatistics::Chunk)
         // compute column page statistics for the tag column
diff --git a/datafusion-examples/examples/parquet_embedded_index.rs b/datafusion-examples/examples/data_io/parquet_embedded_index.rs
similarity index 94%
rename from datafusion-examples/examples/parquet_embedded_index.rs
rename to datafusion-examples/examples/data_io/parquet_embedded_index.rs
index 3cbe189147752..40b5b468ff5bf 100644
--- a/datafusion-examples/examples/parquet_embedded_index.rs
+++ b/datafusion-examples/examples/data_io/parquet_embedded_index.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+//!
 //! Embedding and using a custom index in Parquet files
 //!
 //! # Background
@@ -116,11 +118,11 @@ use arrow::record_batch::RecordBatch;
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
 use datafusion::catalog::{Session, TableProvider};
-use datafusion::common::{exec_err, HashMap, HashSet, Result};
+use datafusion::common::{HashMap, HashSet, Result, exec_err};
+use datafusion::datasource::TableType;
 use datafusion::datasource::listing::PartitionedFile;
 use datafusion::datasource::memory::DataSourceExec;
 use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
-use datafusion::datasource::TableType;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::logical_expr::{Operator, TableProviderFilterPushDown};
 use datafusion::parquet::arrow::ArrowWriter;
@@ -130,12 +132,37 @@ use datafusion::parquet::file::reader::{FileReader, SerializedFileReader};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
-use std::fs::{read_dir, File};
+use std::fs::{File, read_dir};
 use std::io::{Read, Seek, SeekFrom, Write};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use tempfile::TempDir;
 
+/// Store a custom index inside a Parquet file and use it to speed up queries
+pub async fn parquet_embedded_index() -> Result<()> {
+    // 1. Create temp dir and write 3 Parquet files with different category sets
+    let tmp = TempDir::new()?;
+    let dir = tmp.path();
+    write_file_with_index(&dir.join("a.parquet"), &["foo", "bar", "foo"])?;
+    write_file_with_index(&dir.join("b.parquet"), &["baz", "qux"])?;
+    write_file_with_index(&dir.join("c.parquet"), &["foo", "quux", "quux"])?;
+
+    // 2. Register our custom TableProvider
+    let field = Field::new("category", DataType::Utf8, false);
+    let schema_ref = Arc::new(Schema::new(vec![field]));
+    let provider = Arc::new(DistinctIndexTable::try_new(dir, schema_ref.clone())?);
+
+    let ctx = SessionContext::new();
+    ctx.register_table("t", provider)?;
+
+    // 3. Run a query: only files containing 'foo' get scanned; the rest are
+    //    pruned based on the distinct index.
+ let df = ctx.sql("SELECT * FROM t WHERE category = 'foo'").await?; + df.show().await?; + + Ok(()) +} + /// An index of distinct values for a single column /// /// In this example the index is a simple set of strings, but in a real @@ -366,9 +393,6 @@ fn get_key_value<'a>(file_meta_data: &'a FileMetaData, key: &'_ str) -> Option<& /// Implement TableProvider for DistinctIndexTable, using the distinct index to prune files #[async_trait] impl TableProvider for DistinctIndexTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -392,21 +416,15 @@ impl TableProvider for DistinctIndexTable { // equality analysis or write your own custom logic. let mut target: Option<&str> = None; - if filters.len() == 1 { - if let Expr::BinaryExpr(expr) = &filters[0] { - if expr.op == Operator::Eq { - if let ( - Expr::Column(c), - Expr::Literal(ScalarValue::Utf8(Some(v)), _), - ) = (&*expr.left, &*expr.right) - { - if c.name == "category" { - println!("Filtering for category: {v}"); - target = Some(v); - } - } - } - } + if filters.len() == 1 + && let Expr::BinaryExpr(expr) = &filters[0] + && expr.op == Operator::Eq + && let (Expr::Column(c), Expr::Literal(ScalarValue::Utf8(Some(v)), _)) = + (&*expr.left, &*expr.right) + && c.name == "category" + { + println!("Filtering for category: {v}"); + target = Some(v); } // Determine which files to scan let files_to_scan: Vec<_> = self @@ -426,8 +444,10 @@ impl TableProvider for DistinctIndexTable { // Build ParquetSource to actually read the files let url = ObjectStoreUrl::parse("file://")?; - let source = Arc::new(ParquetSource::default().with_enable_page_index(true)); - let mut builder = FileScanConfigBuilder::new(url, self.schema.clone(), source); + let source = Arc::new( + ParquetSource::new(self.schema.clone()).with_enable_page_index(true), + ); + let mut builder = FileScanConfigBuilder::new(url, source); for file in files_to_scan { let path = self.dir.join(file); let len = std::fs::metadata(&path)?.len(); @@ -450,28 +470,3 @@ impl TableProvider for DistinctIndexTable { Ok(vec![TableProviderFilterPushDown::Inexact; fs.len()]) } } - -#[tokio::main] -async fn main() -> Result<()> { - // 1. Create temp dir and write 3 Parquet files with different category sets - let tmp = TempDir::new()?; - let dir = tmp.path(); - write_file_with_index(&dir.join("a.parquet"), &["foo", "bar", "foo"])?; - write_file_with_index(&dir.join("b.parquet"), &["baz", "qux"])?; - write_file_with_index(&dir.join("c.parquet"), &["foo", "quux", "quux"])?; - - // 2. Register our custom TableProvider - let field = Field::new("category", DataType::Utf8, false); - let schema_ref = Arc::new(Schema::new(vec![field])); - let provider = Arc::new(DistinctIndexTable::try_new(dir, schema_ref.clone())?); - - let ctx = SessionContext::new(); - ctx.register_table("t", provider)?; - - // 3. Run a query: only files containing 'foo' get scanned. The rest are pruned. - // based on the distinct index. 
- let df = ctx.sql("SELECT * FROM t WHERE category = 'foo'").await?; - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/parquet_encrypted.rs b/datafusion-examples/examples/data_io/parquet_encrypted.rs similarity index 75% rename from datafusion-examples/examples/parquet_encrypted.rs rename to datafusion-examples/examples/data_io/parquet_encrypted.rs index 690d9f2a5f140..f73c538d1c4d9 100644 --- a/datafusion-examples/examples/parquet_encrypted.rs +++ b/datafusion-examples/examples/data_io/parquet_encrypted.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +use std::sync::Arc; + use datafusion::common::DataFusionError; use datafusion::config::{ConfigFileEncryptionProperties, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; @@ -22,21 +26,21 @@ use datafusion::logical_expr::{col, lit}; use datafusion::parquet::encryption::decrypt::FileDecryptionProperties; use datafusion::parquet::encryption::encrypt::FileEncryptionProperties; use datafusion::prelude::{ParquetReadOptions, SessionContext}; -use std::sync::Arc; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use tempfile::TempDir; -#[tokio::main] -async fn main() -> datafusion::common::Result<()> { +/// Read and write encrypted Parquet files using DataFusion +pub async fn parquet_encrypted() -> datafusion::common::Result<()> { // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); - // Find the local path of "alltypes_plain.parquet" - let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Read the sample parquet file let parquet_df = ctx - .read_parquet(filename, ParquetReadOptions::default()) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // Show information from the dataframe @@ -50,30 +54,33 @@ async fn main() -> datafusion::common::Result<()> { let (encrypt, decrypt) = setup_encryption(&parquet_df)?; // Create a temporary file location for the encrypted parquet file - let tmp_dir = TempDir::new()?; - let tempfile = tmp_dir.path().join("alltypes_plain-encrypted.parquet"); - let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + let tmp_source = TempDir::new()?; + let tempfile = tmp_source.path().join("cars_encrypted.parquet"); // Write encrypted parquet let mut options = TableParquetOptions::default(); options.crypto.file_encryption = Some(ConfigFileEncryptionProperties::from(&encrypt)); parquet_df .write_parquet( - tempfile_str.as_str(), + tempfile.to_str().unwrap(), DataFrameWriteOptions::new().with_single_file_output(true), Some(options), ) .await?; - // Read encrypted parquet + // Read encrypted parquet back as a DataFrame using matching decryption config let ctx: SessionContext = SessionContext::new(); let read_options = - ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + ParquetReadOptions::default().file_decryption_properties((&decrypt).try_into()?); - let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + let encrypted_parquet_df = ctx + .read_parquet(tempfile.to_str().unwrap(), read_options) + .await?; // Show information from the 
dataframe - println!("\n\n==============================================================================="); + println!( + "\n\n===============================================================================" + ); println!("Encrypted Parquet DataFrame:"); query_dataframe(&encrypted_parquet_df).await?; @@ -87,11 +94,12 @@ async fn query_dataframe(df: &DataFrame) -> Result<(), DataFusionError> { df.clone().describe().await?.show().await?; // Select three columns and filter the results - // so that only rows where id > 1 are returned + // so that only rows where speed > 5 are returned + // select car, speed, time from t where speed > 5 println!("\nSelected rows and columns:"); df.clone() - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(5)))? + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(5)))? .show() .await?; diff --git a/datafusion-examples/examples/parquet_encrypted_with_kms.rs b/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs similarity index 99% rename from datafusion-examples/examples/parquet_encrypted_with_kms.rs rename to datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs index 45bfd183773a0..1a9bf56c09b35 100644 --- a/datafusion-examples/examples/parquet_encrypted_with_kms.rs +++ b/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow_schema::SchemaRef; use async_trait::async_trait; @@ -53,8 +55,7 @@ const ENCRYPTION_FACTORY_ID: &str = "example.mock_kms_encryption"; /// which is not a secure way to store encryption keys. /// For production use, it is recommended to use a key-management service (KMS) to encrypt /// data encryption keys. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parquet_encrypted_with_kms() -> Result<()> { let ctx = SessionContext::new(); // Register an `EncryptionFactory` implementation to be used for Parquet encryption diff --git a/datafusion-examples/examples/parquet_exec_visitor.rs b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs similarity index 72% rename from datafusion-examples/examples/parquet_exec_visitor.rs rename to datafusion-examples/examples/data_io/parquet_exec_visitor.rs index 84f92d4f450e1..d1951b2d9904d 100644 --- a/datafusion-examples/examples/parquet_exec_visitor.rs +++ b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+
 use std::sync::Arc;
 
 use datafusion::datasource::file_format::parquet::ParquetFormat;
@@ -25,34 +27,37 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::context::SessionContext;
 use datafusion::physical_plan::metrics::MetricValue;
 use datafusion::physical_plan::{
-    execute_stream, visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor,
+    ExecutionPlan, ExecutionPlanVisitor, execute_stream, visit_execution_plan,
 };
+use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet};
 use futures::StreamExt;
 
 /// Example of collecting metrics after execution by visiting the `ExecutionPlan`
-#[tokio::main]
-async fn main() {
+pub async fn parquet_exec_visitor() -> datafusion::common::Result<()> {
     let ctx = SessionContext::new();
 
-    let test_data = datafusion::test_util::parquet_test_data();
+    // Convert the CSV input into a temporary Parquet directory for querying
+    let dataset = ExampleDataset::Cars;
+    let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?;
 
     // Configure listing options
     let file_format = ParquetFormat::default().with_enable_pruning(true);
     let listing_options = ListingOptions::new(Arc::new(file_format));
 
+    let table_path = parquet_temp.file_uri()?;
+
     // First example where we use an absolute path, which requires no additional setup.
-    let _ = ctx
-        .register_listing_table(
-            "my_table",
-            &format!("file://{test_data}/alltypes_plain.parquet"),
-            listing_options.clone(),
-            None,
-            None,
-        )
-        .await;
-
-    let df = ctx.sql("SELECT * FROM my_table").await.unwrap();
-    let plan = df.create_physical_plan().await.unwrap();
+    ctx.register_listing_table(
+        "my_table",
+        &table_path,
+        listing_options.clone(),
+        None,
+        None,
+    )
+    .await?;
+
+    let df = ctx.sql("SELECT * FROM my_table").await?;
+    let plan = df.create_physical_plan().await?;
 
     // Create empty visitor
     let mut visitor = ParquetExecVisitor {
@@ -63,12 +68,12 @@
     // Make sure you execute the plan to collect actual execution statistics.
     // For example, here the `file_scan_config` is known without executing,
    // but the `bytes_scanned` would be None if we did not execute.
-    let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx()).unwrap();
+    let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx())?;
     while let Some(batch) = batch_stream.next().await {
         println!("Batch rows: {}", batch.unwrap().num_rows());
     }
 
-    visit_execution_plan(plan.as_ref(), &mut visitor).unwrap();
+    visit_execution_plan(plan.as_ref(), &mut visitor)?;
 
     println!(
         "ParquetExecVisitor bytes_scanned: {:?}",
@@ -78,6 +83,8 @@
         "ParquetExecVisitor file_groups: {:?}",
         visitor.file_groups.unwrap()
    );
+
+    Ok(())
 }
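Condensed from the function above, since the ordering is the load-bearing detail: metrics such as `bytes_scanned` are only recorded while the stream is actually driven. A minimal sketch of the execute-then-visit pattern, with the visitor step elided:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::{execute_stream, ExecutionPlan};
use datafusion::prelude::SessionContext;
use futures::StreamExt;

async fn run_then_inspect(ctx: &SessionContext, plan: Arc<dyn ExecutionPlan>) -> Result<()> {
    // 1. Drive the stream to completion; scan metrics accumulate as batches flow.
    let mut stream = execute_stream(Arc::clone(&plan), ctx.task_ctx())?;
    while let Some(batch) = stream.next().await {
        batch?; // propagate execution errors
    }
    // 2. Only now will visit_execution_plan observe populated metrics
    //    such as `bytes_scanned` on the scan nodes.
    Ok(())
}
```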
 /// Define a struct with fields to hold the execution information you want to
@@ -97,18 +104,17 @@ impl ExecutionPlanVisitor for ParquetExecVisitor {
     /// or `post_visit` (visit each node after its children/inputs)
     fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result<bool, Self::Error> {
         // If needed match on a specific `ExecutionPlan` node type
-        if let Some(data_source_exec) = plan.as_any().downcast_ref::<DataSourceExec>() {
-            if let Some((file_config, _)) =
+        if let Some(data_source_exec) = plan.downcast_ref::<DataSourceExec>()
+            && let Some((file_config, _)) =
                 data_source_exec.downcast_to_file_source::<ParquetSource>()
-            {
-                self.file_groups = Some(file_config.file_groups.clone());
-
-                let metrics = match data_source_exec.metrics() {
-                    None => return Ok(true),
-                    Some(metrics) => metrics,
-                };
-                self.bytes_scanned = metrics.sum_by_name("bytes_scanned");
-            }
+        {
+            self.file_groups = Some(file_config.file_groups.clone());
+
+            let metrics = match data_source_exec.metrics() {
+                None => return Ok(true),
+                Some(metrics) => metrics,
+            };
+            self.bytes_scanned = metrics.sum_by_name("bytes_scanned");
         }
         Ok(true)
     }
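Several hunks in this diff (in `parquet_advanced_index.rs` above and `parquet_index.rs` below) migrate to the builder shape where the file schema travels with the source rather than with `FileScanConfigBuilder`. A small sketch of the new wiring, built only from calls that appear in this diff:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
use datafusion::error::Result;
use datafusion::execution::object_store::ObjectStoreUrl;

fn build_scan_config() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    // The schema is now handed to the source...
    let source = Arc::new(ParquetSource::new(schema));
    // ...so the builder takes just the object store URL and the source.
    let _builder = FileScanConfigBuilder::new(ObjectStoreUrl::parse("file://")?, source);
    Ok(())
}
```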
diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/data_io/parquet_index.rs
similarity index 97%
rename from datafusion-examples/examples/parquet_index.rs
rename to datafusion-examples/examples/data_io/parquet_index.rs
index 127c55da982c8..9be84d8249342 100644
--- a/datafusion-examples/examples/parquet_index.rs
+++ b/datafusion-examples/examples/data_io/parquet_index.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+
 use arrow::array::{
     Array, ArrayRef, AsArray, BooleanArray, Int32Array, RecordBatch, StringArray,
     UInt64Array,
@@ -25,33 +27,32 @@ use async_trait::async_trait;
 use datafusion::catalog::Session;
 use datafusion::common::pruning::PruningStatistics;
 use datafusion::common::{
-    internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue,
+    DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err,
 };
+use datafusion::datasource::TableProvider;
 use datafusion::datasource::listing::PartitionedFile;
 use datafusion::datasource::memory::DataSourceExec;
 use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
-use datafusion::datasource::TableProvider;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::logical_expr::{
-    utils::conjunction, TableProviderFilterPushDown, TableType,
+    TableProviderFilterPushDown, TableType, utils::conjunction,
 };
 use datafusion::parquet::arrow::arrow_reader::statistics::StatisticsConverter;
 use datafusion::parquet::arrow::{
-    arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter,
+    ArrowWriter, arrow_reader::ParquetRecordBatchReaderBuilder,
 };
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_optimizer::pruning::PruningPredicate;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::*;
-use std::any::Any;
 use std::collections::HashSet;
 use std::fmt::Display;
 use std::fs;
 use std::fs::{DirEntry, File};
 use std::ops::Range;
 use std::path::{Path, PathBuf};
-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use tempfile::TempDir;
 use url::Url;
@@ -99,12 +100,10 @@ use url::Url;
 /// Thus some parquet files are  │               │
 /// "pruned" and thus are not    └─────────────┘
 /// scanned at all               Parquet Files
-///
 /// ```
 ///
 /// [`ListingTable`]: datafusion::datasource::listing::ListingTable
-#[tokio::main]
-async fn main() -> Result<()> {
+pub async fn parquet_index() -> Result<()> {
     // Demo data has three files, each with schema
     // * file_name (string)
     // * value (int32)
@@ -208,10 +207,6 @@ impl IndexTableProvider {
 
 #[async_trait]
 impl TableProvider for IndexTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         self.index.schema().clone()
     }
@@ -243,10 +238,11 @@ impl TableProvider for IndexTableProvider {
         let files = self.index.get_files(predicate.clone())?;
 
         let object_store_url = ObjectStoreUrl::parse("file://")?;
-        let source = Arc::new(ParquetSource::default().with_predicate(predicate));
+        let source =
+            Arc::new(ParquetSource::new(self.schema()).with_predicate(predicate));
         let mut file_scan_config_builder =
-            FileScanConfigBuilder::new(object_store_url, self.schema(), source)
-                .with_projection_indices(projection.cloned())
+            FileScanConfigBuilder::new(object_store_url, source)
+                .with_projection_indices(projection.cloned())?
                 .with_limit(limit);
 
         // Transform to the format needed to pass to DataSourceExec
@@ -461,7 +457,7 @@ impl PruningStatistics for ParquetMetadataIndex {
     }
 
     /// return the row counts for each file
-    fn row_counts(&self, _column: &Column) -> Option<UInt64Array> {
+    fn row_counts(&self) -> Option<UInt64Array> {
         Some(self.row_counts_ref().clone())
     }
@@ -510,7 +506,7 @@ impl ParquetMetadataIndexBuilder {
 
         // Get the schema of the file. A real system might have to handle the
         // case where the schema of the file is not the same as the schema of
-        // the other files e.g. using SchemaAdapter.
+        // the other files e.g. using PhysicalExprAdapterFactory.
if self.file_schema.is_none() { self.file_schema = Some(reader.schema().clone()); } diff --git a/datafusion-examples/examples/query-http-csv.rs b/datafusion-examples/examples/data_io/query_http_csv.rs similarity index 91% rename from datafusion-examples/examples/query-http-csv.rs rename to datafusion-examples/examples/data_io/query_http_csv.rs index fa3fd2ac068df..71421e6270ccb 100644 --- a/datafusion-examples/examples/query-http-csv.rs +++ b/datafusion-examples/examples/data_io/query_http_csv.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; use object_store::http::HttpBuilder; use std::sync::Arc; use url::Url; -/// This example demonstrates executing a simple query against an Arrow data source (CSV) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { +/// Configure `object_store` and run a query against files via HTTP +pub async fn query_http_csv() -> Result<()> { // create local execution context let ctx = SessionContext::new(); diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/data_io/remote_catalog.rs similarity index 98% rename from datafusion-examples/examples/remote_catalog.rs rename to datafusion-examples/examples/data_io/remote_catalog.rs index 74575554ec0af..16814752b3ec2 100644 --- a/datafusion-examples/examples/remote_catalog.rs +++ b/datafusion-examples/examples/data_io/remote_catalog.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! /// This example shows how to implement the DataFusion [`CatalogProvider`] API /// for catalogs that are remote (require network access) and/or offer only /// asynchronous APIs such as [Polaris], [Unity], and [Hive]. @@ -39,15 +41,14 @@ use datafusion::common::{assert_batches_eq, internal_datafusion_err, plan_err}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::SendableRecordBatchStream; use datafusion::logical_expr::{Expr, TableType}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::prelude::{DataFrame, SessionContext}; use futures::TryStreamExt; -use std::any::Any; use std::sync::Arc; -#[tokio::main] -async fn main() -> Result<()> { +/// Interfacing with a remote catalog (e.g. over a network) +pub async fn remote_catalog() -> Result<()> { // As always, we create a session context to interact with DataFusion let ctx = SessionContext::new(); @@ -222,10 +223,6 @@ impl RemoteTable { /// Implement the DataFusion Catalog API for [`RemoteTable`] #[async_trait] impl TableProvider for RemoteTable { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } diff --git a/datafusion-examples/examples/dataframe/cache_factory.rs b/datafusion-examples/examples/dataframe/cache_factory.rs new file mode 100644 index 0000000000000..a92c3dc4ce26a --- /dev/null +++ b/datafusion-examples/examples/dataframe/cache_factory.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! See `main.rs` for how to run it.
+
+use std::fmt::Debug;
+use std::hash::Hash;
+use std::sync::{Arc, RwLock};
+
+use arrow::array::RecordBatch;
+use async_trait::async_trait;
+use datafusion::catalog::memory::MemorySourceConfig;
+use datafusion::common::DFSchemaRef;
+use datafusion::error::Result;
+use datafusion::execution::context::QueryPlanner;
+use datafusion::execution::session_state::CacheFactory;
+use datafusion::execution::{SessionState, SessionStateBuilder};
+use datafusion::logical_expr::{
+    Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
+};
+use datafusion::physical_plan::{ExecutionPlan, collect_partitioned};
+use datafusion::physical_planner::{
+    DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner,
+};
+use datafusion::prelude::*;
+use datafusion_common::HashMap;
+use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet};
+
+/// This example demonstrates how to leverage [CacheFactory] to implement custom caching strategies for dataframes in DataFusion.
+/// By default, [DataFrame::cache] in DataFusion is eager and creates an in-memory table. This example shows a basic alternative implementation for lazy caching.
+/// Specifically, it implements:
+/// - A [CustomCacheFactory] that creates a logical node [CacheNode] representing the cache operation.
+/// - A [CacheNodePlanner] (an [ExtensionPlanner]) that understands [CacheNode] and performs caching.
+/// - A [CacheNodeQueryPlanner] that installs [CacheNodePlanner].
+/// - A simple in-memory [CacheManager] that stores cached [RecordBatch]es. Note that the implementation for this example is very naive and only implements put, but for real production use cases cache eviction and drop should also be implemented.
+pub async fn cache_dataframe_with_custom_logic() -> Result<()> {
+    let session_state = SessionStateBuilder::new()
+        .with_cache_factory(Some(Arc::new(CustomCacheFactory {})))
+        .with_query_planner(Arc::new(CacheNodeQueryPlanner::default()))
+        .build();
+    let ctx = SessionContext::new_with_state(session_state);
+
+    // Convert the CSV input into a temporary Parquet directory for querying
+    let dataset = ExampleDataset::Cars;
+    let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?;
+
+    // Read the parquet files and show its schema using 'describe'
+    let parquet_df = ctx
+        .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default())
+        .await?;
+
+    let df_cached = parquet_df
+        .select_columns(&["car", "speed", "time"])?
+        .filter(col("speed").gt(lit(1.0)))?
+        .cache()
+        .await?;
+
+    let df1 = df_cached.clone().filter(col("car").eq(lit("red")))?;
+    let df2 = df1.clone().sort(vec![col("car").sort(true, false)])?;
+
+    // should see log for caching only once
+    df_cached.show().await?;
+    df1.show().await?;
+    df2.show().await?;
+
+    Ok(())
+}
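For contrast with the lazy strategy implemented below: a minimal sketch of DataFusion's default, eager `DataFrame::cache`, which (as the doc comment above states) runs the plan immediately and materializes the result in an in-memory table:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn default_eager_cache() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.sql("SELECT 1 AS x").await?;
    // Default behavior: executes the query now and stores the batches
    // in an in-memory table backing the returned DataFrame.
    let cached = df.cache().await?;
    cached.show().await?; // served from memory, no recomputation
    Ok(())
}
```

The `CacheFactory` approach below instead defers that work to physical planning, so the cache is only populated the first time a plan containing the `CacheNode` actually executes.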
+#[derive(Debug)]
+struct CustomCacheFactory {}
+
+impl CacheFactory for CustomCacheFactory {
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        _session_state: &SessionState,
+    ) -> Result<LogicalPlan> {
+        Ok(LogicalPlan::Extension(Extension {
+            node: Arc::new(CacheNode { input: plan }),
+        }))
+    }
+}
+
+#[derive(PartialEq, Eq, PartialOrd, Hash, Debug)]
+struct CacheNode {
+    input: LogicalPlan,
+}
+
+impl UserDefinedLogicalNodeCore for CacheNode {
+    fn name(&self) -> &str {
+        "CacheNode"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "CacheNode")
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        assert_eq!(inputs.len(), 1, "input size must be one");
+        Ok(Self {
+            input: inputs.swap_remove(0),
+        })
+    }
+}
+
+struct CacheNodePlanner {
+    cache_manager: Arc<RwLock<CacheManager>>,
+}
+
+#[async_trait]
+impl ExtensionPlanner for CacheNodePlanner {
+    async fn plan_extension(
+        &self,
+        _planner: &dyn PhysicalPlanner,
+        node: &dyn UserDefinedLogicalNode,
+        logical_inputs: &[&LogicalPlan],
+        physical_inputs: &[Arc<dyn ExecutionPlan>],
+        session_state: &SessionState,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        if let Some(cache_node) = node.as_any().downcast_ref::<CacheNode>() {
+            assert_eq!(logical_inputs.len(), 1, "Inconsistent number of inputs");
+            assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs");
+            if self
+                .cache_manager
+                .read()
+                .unwrap()
+                .get(&cache_node.input)
+                .is_none()
+            {
+                let ctx = session_state.task_ctx();
+                println!("caching in memory");
+                let batches =
+                    collect_partitioned(physical_inputs[0].clone(), ctx).await?;
+                self.cache_manager
+                    .write()
+                    .unwrap()
+                    .put(cache_node.input.clone(), batches);
+            } else {
+                println!("fetching directly from cache manager");
+            }
+            Ok(self
+                .cache_manager
+                .read()
+                .unwrap()
+                .get(&cache_node.input)
+                .map(|batches| {
+                    let exec: Arc<dyn ExecutionPlan> = MemorySourceConfig::try_new_exec(
+                        batches,
+                        physical_inputs[0].schema(),
+                        None,
+                    )
+                    .unwrap();
+                    exec
+                }))
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+struct CacheNodeQueryPlanner {
+    cache_manager: Arc<RwLock<CacheManager>>,
+}
+
+#[async_trait]
+impl QueryPlanner for CacheNodeQueryPlanner {
+    async fn create_physical_plan(
+        &self,
+        logical_plan: &LogicalPlan,
+        session_state: &SessionState,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let physical_planner =
+            DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(
+                CacheNodePlanner {
+                    cache_manager: Arc::clone(&self.cache_manager),
+                },
+            )]);
+        physical_planner
+            .create_physical_plan(logical_plan, session_state)
+            .await
+    }
+}
+
+// This naive implementation only includes put, but for real production use cases cache eviction and drop should also be implemented.
+// This naive implementation only implements `put`; a production cache should
+// also support eviction and drop.
+#[derive(Debug, Default)]
+struct CacheManager {
+    cache: HashMap<LogicalPlan, Vec<Vec<RecordBatch>>>,
+}
+
+impl CacheManager {
+    pub fn put(&mut self, k: LogicalPlan, v: Vec<Vec<RecordBatch>>) {
+        self.cache.insert(k, v);
+    }
+
+    pub fn get(&self, k: &LogicalPlan) -> Option<&Vec<Vec<RecordBatch>>> {
+        self.cache.get(k)
+    }
+}
diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe/dataframe.rs
similarity index 73%
rename from datafusion-examples/examples/dataframe.rs
rename to datafusion-examples/examples/dataframe/dataframe.rs
index a5ee571a14764..dde19cb476f14 100644
--- a/datafusion-examples/examples/dataframe.rs
+++ b/datafusion-examples/examples/dataframe/dataframe.rs
@@ -15,22 +15,26 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+
+use std::fs::File;
+use std::io::Write;
+use std::sync::Arc;
+
 use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema};
 use datafusion::catalog::MemTable;
+use datafusion::common::ScalarValue;
 use datafusion::common::config::CsvOptions;
 use datafusion::common::parsers::CompressionTypeVariant;
-use datafusion::common::DataFusionError;
-use datafusion::common::ScalarValue;
 use datafusion::dataframe::DataFrameWriteOptions;
 use datafusion::error::Result;
 use datafusion::functions_aggregate::average::avg;
 use datafusion::functions_aggregate::min_max::max;
 use datafusion::prelude::*;
-use std::fs::File;
-use std::io::Write;
-use std::sync::Arc;
-use tempfile::tempdir;
+use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet};
+use tempfile::{TempDir, tempdir};
+use tokio::fs::create_dir_all;
 
 /// This example demonstrates using DataFusion's DataFrame API
 ///
@@ -39,6 +43,7 @@ use tempfile::tempdir;
 /// * [read_parquet]: execute queries against parquet files
 /// * [read_csv]: execute queries against csv files
 /// * [read_memory]: execute queries against in-memory arrow data
+/// * [read_memory_macro]: execute queries against in-memory arrow data using a macro
 ///
 /// # Writing out to local storage
 ///
@@ -53,12 +58,7 @@
 /// * [where_scalar_subquery]: execute a scalar subquery
 /// * [where_in_subquery]: execute a subquery with an IN clause
 /// * [where_exist_subquery]: execute a subquery with an EXISTS clause
-///
-/// # Querying data
-///
-/// * [query_to_date]: execute queries against parquet files
-#[tokio::main]
-async fn main() -> Result<()> {
+pub async fn dataframe_example() -> Result<()> {
     env_logger::init();
     // The SessionContext is the main high level API for interacting with DataFusion
     let ctx = SessionContext::new();
@@ -67,8 +67,8 @@
     read_memory(&ctx).await?;
     read_memory_macro().await?;
     write_out(&ctx).await?;
-    register_aggregate_test_data("t1", &ctx).await?;
-    register_aggregate_test_data("t2", &ctx).await?;
+    register_cars_test_data("t1", &ctx).await?;
+    register_cars_test_data("t2", &ctx).await?;
     where_scalar_subquery(&ctx).await?;
     where_in_subquery(&ctx).await?;
     where_exist_subquery(&ctx).await?;
@@ -80,23 +80,24 @@
 /// 2. Show the schema
 /// 3.
Select columns and rows async fn read_parquet(ctx: &SessionContext) -> Result<()> { - // Find the local path of "alltypes_plain.parquet" - let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(ctx, &dataset.path()).await?; // Read the parquet files and show its schema using 'describe' let parquet_df = ctx - .read_parquet(filename, ParquetReadOptions::default()) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // show its schema using 'describe' parquet_df.clone().describe().await?.show().await?; // Select three columns and filter the results - // so that only rows where id > 1 are returned + // so that only rows where speed > 1 are returned + // select car, speed, time from t where speed > 1 parquet_df - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(1)))? + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(1)))? .show() .await?; @@ -199,7 +200,7 @@ async fn read_memory_macro() -> Result<()> { /// 2. Write out a DataFrame to a parquet file /// 3. Write out a DataFrame to a csv file /// 4. Write out a DataFrame to a json file -async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { +async fn write_out(ctx: &SessionContext) -> Result<()> { let array = StringViewArray::from(vec!["a", "b", "c"]); let schema = Arc::new(Schema::new(vec![Field::new( "tablecol1", @@ -211,15 +212,26 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr ctx.register_table("initial_data", Arc::new(mem_table))?; let df = ctx.table("initial_data").await?; - ctx.sql( - "create external table - test(tablecol1 varchar) - stored as parquet - location './datafusion-examples/test_table/'", - ) - .await? - .collect() - .await?; + // Create a single temp root with subdirectories + let tmp_root = TempDir::new()?; + let examples_root = tmp_root.path().join("datafusion-examples"); + create_dir_all(&examples_root).await?; + let table_dir = examples_root.join("test_table"); + let parquet_dir = examples_root.join("test_parquet"); + let csv_dir = examples_root.join("test_csv"); + let json_dir = examples_root.join("test_json"); + create_dir_all(&table_dir).await?; + create_dir_all(&parquet_dir).await?; + create_dir_all(&csv_dir).await?; + create_dir_all(&json_dir).await?; + + let create_sql = format!( + "CREATE EXTERNAL TABLE test(tablecol1 varchar) + STORED AS parquet + LOCATION '{}'", + table_dir.display() + ); + ctx.sql(&create_sql).await?.collect().await?; // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). 
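    // (Illustration, based on the surrounding comments rather than shown in
    // this hunk: the `write_table` call that follows in the unchanged code is
    // presumably
    //     df.clone().write_table("test", DataFrameWriteOptions::new()).await?;
    // which hands the batches to the registered provider's insert_into path.)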
// The behavior of write_table depends on the TableProvider's implementation @@ -230,7 +242,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_parquet( - "./datafusion-examples/test_parquet/", + parquet_dir.to_str().unwrap(), DataFrameWriteOptions::new(), None, ) @@ -238,7 +250,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_csv( - "./datafusion-examples/test_csv/", + csv_dir.to_str().unwrap(), // DataFrameWriteOptions contains options which control how data is written // such as compression codec DataFrameWriteOptions::new(), @@ -248,7 +260,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_json( - "./datafusion-examples/test_json/", + json_dir.to_str().unwrap(), DataFrameWriteOptions::new(), None, ) @@ -258,7 +270,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr } /// Use the DataFrame API to execute the following subquery: -/// select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; +/// select car, speed from t1 where (select avg(t2.speed) from t2 where t1.car = t2.car) > 0 limit 3; async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -266,14 +278,14 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .aggregate(vec![], vec![avg(col("t2.c2"))])? - .select(vec![avg(col("t2.c2"))])? + .filter(out_ref_col(DataType::Utf8, "t1.car").eq(col("t2.car")))? + .aggregate(vec![], vec![avg(col("t2.speed"))])? + .select(vec![avg(col("t2.speed"))])? .into_unoptimized_plan(), )) - .gt(lit(0u8)), + .gt(lit(0.0)), )? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? .show() .await?; @@ -281,22 +293,24 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select t1.c1, t1.c2 from t1 where t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; +/// select t1.car, t1.speed from t1 where t1.speed in (select max(t2.speed) from t2 where t2.car = 'red') limit 3; async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(in_subquery( - col("t1.c2"), + col("t1.speed"), Arc::new( ctx.table("t2") .await? - .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? - .aggregate(vec![], vec![max(col("t2.c2"))])? - .select(vec![max(col("t2.c2"))])? + .filter( + col("t2.car").eq(lit(ScalarValue::Utf8(Some("red".to_string())))), + )? + .aggregate(vec![], vec![max(col("t2.speed"))])? + .select(vec![max(col("t2.speed"))])? .into_unoptimized_plan(), ), ))? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? .show() .await?; @@ -304,31 +318,27 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select t1.c1, t1.c2 from t1 where exists (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; +/// select t1.car, t1.speed from t1 where exists (select t2.speed from t2 where t1.car = t2.car) limit 3; async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(exists(Arc::new( ctx.table("t2") .await? 
- .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .select(vec![col("t2.c2")])? + .filter(out_ref_col(DataType::Utf8, "t1.car").eq(col("t2.car")))? + .select(vec![col("t2.speed")])? .into_unoptimized_plan(), )))? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? .show() .await?; Ok(()) } -async fn register_aggregate_test_data(name: &str, ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); - ctx.register_csv( - name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::default(), - ) - .await?; +async fn register_cars_test_data(name: &str, ctx: &SessionContext) -> Result<()> { + let dataset = ExampleDataset::Cars; + ctx.register_csv(name, dataset.path_str()?, CsvReadOptions::default()) + .await?; Ok(()) } diff --git a/datafusion-examples/examples/dataframe/deserialize_to_struct.rs b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs new file mode 100644 index 0000000000000..b031225dc9b69 --- /dev/null +++ b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. + +use arrow::array::{Array, Float64Array, StringViewArray}; +use datafusion::common::assert_batches_eq; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use futures::StreamExt; + +/// This example shows how to convert query results into Rust structs by using +/// the Arrow APIs to convert the results into Rust native types. 
+///
+/// This is a bit tricky initially as the results are returned as columns stored
+/// as [ArrayRef]
+///
+/// [ArrayRef]: arrow::array::ArrayRef
+pub async fn deserialize_to_struct() -> Result<()> {
+    // Run a query that returns two columns of data
+    let ctx = SessionContext::new();
+
+    // Convert the CSV input into a temporary Parquet directory for querying
+    let dataset = ExampleDataset::Cars;
+    let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?;
+
+    ctx.register_parquet(
+        "cars",
+        parquet_temp.path_str()?,
+        ParquetReadOptions::default(),
+    )
+    .await?;
+
+    let df = ctx
+        .sql("SELECT car, speed FROM cars ORDER BY speed LIMIT 50")
+        .await?;
+
+    // print out the results, showing we have car and speed columns and a deterministic ordering
+    let results = df.clone().collect().await?;
+    assert_batches_eq!(
+        [
+            "+-------+-------+",
+            "| car   | speed |",
+            "+-------+-------+",
+            "| red   | 0.0   |",
+            "| red   | 1.0   |",
+            "| green | 2.0   |",
+            "| red   | 3.0   |",
+            "| red   | 7.0   |",
+            "| red   | 7.1   |",
+            "| red   | 7.2   |",
+            "| green | 8.0   |",
+            "| green | 10.0  |",
+            "| green | 10.3  |",
+            "| green | 10.4  |",
+            "| green | 10.5  |",
+            "| green | 11.0  |",
+            "| green | 12.0  |",
+            "| green | 14.0  |",
+            "| green | 15.0  |",
+            "| green | 15.1  |",
+            "| green | 15.2  |",
+            "| red   | 17.0  |",
+            "| red   | 18.0  |",
+            "| red   | 19.0  |",
+            "| red   | 20.0  |",
+            "| red   | 20.3  |",
+            "| red   | 21.4  |",
+            "| red   | 21.5  |",
+            "+-------+-------+",
+        ],
+        &results
+    );
+
+    // We will now convert the query results into a Rust struct
+    let mut stream = df.execute_stream().await?;
+    let mut list: Vec<Data> = vec![];
+
+    // DataFusion produces data in chunks called `RecordBatch`es, which are
+    // typically 8000 rows each. This loop processes each `RecordBatch` as it is
+    // produced by the query plan and adds it to the list
+    while let Some(batch) = stream.next().await.transpose()? {
+        // Each `RecordBatch` has one or more columns. Each column is stored as
+        // an `ArrayRef`. To interact with data using Rust native types we need to
+        // convert these `ArrayRef`s into concrete array types using APIs from
+        // the arrow crate.
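        // (Aside: the downcasts just below target `StringViewArray` and
        // `Float64Array`. arrow-rs also offers `AsArray` convenience accessors,
        // so an equivalent formulation would be roughly:
        //     use arrow::array::AsArray;
        //     use arrow::datatypes::Float64Type;
        //     let car_col = batch.column(0).as_string_view();
        //     let speed_col = batch.column(1).as_primitive::<Float64Type>();
        // with the same panic-on-mismatch behavior.)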
+ + // In this case, we know that each batch has two columns of the Arrow + // types StringView and Float64, so first we cast the two columns to the + // appropriate Arrow PrimitiveArray (this is a fast / zero-copy cast).: + let car_col = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("car column must be Utf8View"); + + let speed_col = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("speed column must be Float64"); + + // With PrimitiveArrays, we can access to the values as native Rust + // types String and f64, and forming the desired `Data` structs + for i in 0..batch.num_rows() { + let car = if car_col.is_null(i) { + None + } else { + Some(car_col.value(i).to_string()) + }; + + let speed = if speed_col.is_null(i) { + None + } else { + Some(speed_col.value(i)) + }; + + list.push(Data { car, speed }); + } + } + + // Finally, we have the results in the list of Rust structs + let res = format!("{list:#?}"); + assert_eq!( + res, + r#"[ + Data { + car: Some( + "red", + ), + speed: Some( + 0.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 1.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 2.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 3.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.1, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.2, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 8.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.3, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.4, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.5, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 11.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 12.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 14.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.1, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.2, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 17.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 18.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 19.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 20.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 20.3, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 21.4, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 21.5, + ), + }, +]"# + ); + + let speed_green_sum: f64 = list + .iter() + .filter(|data| data.car.as_deref() == Some("green")) + .filter_map(|data| data.speed) + .sum(); + let speed_red_sum: f64 = list + .iter() + .filter(|data| data.car.as_deref() == Some("red")) + .filter_map(|data| data.speed) + .sum(); + assert_eq!(speed_green_sum, 133.5); + assert_eq!(speed_red_sum, 162.5); + + Ok(()) +} + +/// This is target struct where we want the query results. 
+#[derive(Debug)]
+struct Data {
+    car: Option<String>,
+    speed: Option<f64>,
+}
diff --git a/datafusion-examples/examples/dataframe/main.rs b/datafusion-examples/examples/dataframe/main.rs
new file mode 100644
index 0000000000000..25b5377d38239
--- /dev/null
+++ b/datafusion-examples/examples/dataframe/main.rs
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Core DataFrame API usage examples
+//!
+//! These examples demonstrate core DataFrame API usage.
+//!
+//! ## Usage
+//! ```bash
+//! cargo run --example dataframe -- [all|dataframe|deserialize_to_struct|cache_factory]
+//! ```
+//!
+//! Each subcommand runs a corresponding example:
+//! - `all` — run all examples included in this module
+//!
+//! - `cache_factory`
+//!   (file: cache_factory.rs, desc: Custom lazy caching for DataFrames using `CacheFactory`)
+//!
+//! - `dataframe`
+//!   (file: dataframe.rs, desc: Query DataFrames from various sources and write output)
+//!
+//! - `deserialize_to_struct`
+//!   (file: deserialize_to_struct.rs, desc: Convert Arrow arrays into Rust structs)
+
+mod cache_factory;
+mod dataframe;
+mod deserialize_to_struct;
+
+use datafusion::error::{DataFusionError, Result};
+use strum::{IntoEnumIterator, VariantNames};
+use strum_macros::{Display, EnumIter, EnumString, VariantNames};
+
+#[derive(EnumIter, EnumString, Display, VariantNames)]
+#[strum(serialize_all = "snake_case")]
+enum ExampleKind {
+    All,
+    Dataframe,
+    DeserializeToStruct,
+    CacheFactory,
+}
+
+impl ExampleKind {
+    const EXAMPLE_NAME: &str = "dataframe";
+
+    fn runnable() -> impl Iterator<Item = Self> {
+        ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All))
+    }
+
+    async fn run(&self) -> Result<()> {
+        match self {
+            ExampleKind::All => {
+                for example in ExampleKind::runnable() {
+                    println!("Running example: {example}");
+                    Box::pin(example.run()).await?;
+                }
+            }
+            ExampleKind::Dataframe => {
+                dataframe::dataframe_example().await?;
+            }
+            ExampleKind::DeserializeToStruct => {
+                deserialize_to_struct::deserialize_to_struct().await?;
+            }
+            ExampleKind::CacheFactory => {
+                cache_factory::cache_dataframe_with_custom_logic().await?;
+            }
+        }
+        Ok(())
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let usage = format!(
+        "Usage: cargo run --example {} -- [{}]",
+        ExampleKind::EXAMPLE_NAME,
+        ExampleKind::VARIANTS.join("|")
+    );
+
+    let example: ExampleKind = std::env::args()
+        .nth(1)
+        .unwrap_or_else(|| ExampleKind::All.to_string())
+        .parse()
+        .map_err(|_| DataFusionError::Execution(format!("Unknown example.
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs deleted file mode 100644 index d6655b3b654f9..0000000000000 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ /dev/null @@ -1,150 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{AsArray, PrimitiveArray}; -use arrow::datatypes::{Float64Type, Int32Type}; -use datafusion::common::assert_batches_eq; -use datafusion::error::Result; -use datafusion::prelude::*; -use futures::StreamExt; - -/// This example shows how to convert query results into Rust structs by using -/// the Arrow APIs to convert the results into Rust native types. -/// -/// This is a bit tricky initially as the results are returned as columns stored -/// as [ArrayRef] -/// -/// [ArrayRef]: arrow::array::ArrayRef -#[tokio::main] -async fn main() -> Result<()> { - // Run a query that returns two columns of data - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; - let df = ctx - .sql("SELECT int_col, double_col FROM alltypes_plain") - .await?; - - // print out the results showing we have an int32 and a float64 column - let results = df.clone().collect().await?; - assert_batches_eq!( - [ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "+---------+------------+", - ], - &results - ); - - // We will now convert the query results into a Rust struct - let mut stream = df.execute_stream().await?; - let mut list = vec![]; - - // DataFusion produces data in chunks called `RecordBatch`es which are - // typically 8000 rows each. This loop processes each `RecordBatch` as it is - // produced by the query plan and adds it to the list - while let Some(b) = stream.next().await.transpose()? { - // Each `RecordBatch` has one or more columns. Each column is stored as - // an `ArrayRef`. To interact with data using Rust native types we need to - // convert these `ArrayRef`s into concrete array types using APIs from - // the arrow crate. 
- - // In this case, we know that each batch has two columns of the Arrow - // types Int32 and Float64, so first we cast the two columns to the - // appropriate Arrow PrimitiveArray (this is a fast / zero-copy cast).: - let int_col: &PrimitiveArray = b.column(0).as_primitive(); - let float_col: &PrimitiveArray = b.column(1).as_primitive(); - - // With PrimitiveArrays, we can access to the values as native Rust - // types i32 and f64, and forming the desired `Data` structs - for (i, f) in int_col.values().iter().zip(float_col.values()) { - list.push(Data { - int_col: *i, - double_col: *f, - }) - } - } - - // Finally, we have the results in the list of Rust structs - let res = format!("{list:#?}"); - assert_eq!( - res, - r#"[ - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, -]"# - ); - - // Use the fields in the struct to avoid clippy complaints - let int_sum = list.iter().fold(0, |acc, x| acc + x.int_col); - let double_sum = list.iter().fold(0.0, |acc, x| acc + x.double_col); - assert_eq!(int_sum, 4); - assert_eq!(double_sum, 40.4); - - Ok(()) -} - -/// This is target struct where we want the query results. -#[derive(Debug)] -struct Data { - int_col: i32, - double_col: f64, -} diff --git a/datafusion-examples/examples/execution_monitoring/main.rs b/datafusion-examples/examples/execution_monitoring/main.rs new file mode 100644 index 0000000000000..8f80c36929ca2 --- /dev/null +++ b/datafusion-examples/examples/execution_monitoring/main.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These examples of memory and performance management +//! +//! These examples demonstrate memory and performance management. +//! +//! ## Usage +//! ```bash +//! cargo run --example execution_monitoring -- [all|mem_pool_exec_plan|mem_pool_tracking|tracing] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `mem_pool_exec_plan` +//! (file: memory_pool_execution_plan.rs, desc: Memory-aware ExecutionPlan with spilling) +//! +//! - `mem_pool_tracking` +//! (file: memory_pool_tracking.rs, desc: Demonstrates memory tracking) +//! +//! - `tracing` +//! 
(file: tracing.rs, desc: Demonstrates tracing integration)
+
+mod memory_pool_execution_plan;
+mod memory_pool_tracking;
+mod tracing;
+
+use datafusion::error::{DataFusionError, Result};
+use strum::{IntoEnumIterator, VariantNames};
+use strum_macros::{Display, EnumIter, EnumString, VariantNames};
+
+#[derive(EnumIter, EnumString, Display, VariantNames)]
+#[strum(serialize_all = "snake_case")]
+enum ExampleKind {
+    All,
+    MemPoolExecPlan,
+    MemPoolTracking,
+    Tracing,
+}
+
+impl ExampleKind {
+    const EXAMPLE_NAME: &str = "execution_monitoring";
+
+    fn runnable() -> impl Iterator<Item = Self> {
+        ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All))
+    }
+
+    async fn run(&self) -> Result<()> {
+        match self {
+            ExampleKind::All => {
+                for example in ExampleKind::runnable() {
+                    println!("Running example: {example}");
+                    Box::pin(example.run()).await?;
+                }
+            }
+            ExampleKind::MemPoolExecPlan => {
+                memory_pool_execution_plan::memory_pool_execution_plan().await?
+            }
+            ExampleKind::MemPoolTracking => {
+                memory_pool_tracking::mem_pool_tracking().await?
+            }
+            ExampleKind::Tracing => tracing::tracing().await?,
+        }
+        Ok(())
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let usage = format!(
+        "Usage: cargo run --example {} -- [{}]",
+        ExampleKind::EXAMPLE_NAME,
+        ExampleKind::VARIANTS.join("|")
+    );
+
+    let example: ExampleKind = std::env::args()
+        .nth(1)
+        .unwrap_or_else(|| ExampleKind::All.to_string())
+        .parse()
+        .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?;
+
+    example.run().await
+}
diff --git a/datafusion-examples/examples/memory_pool_execution_plan.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs
similarity index 90%
rename from datafusion-examples/examples/memory_pool_execution_plan.rs
rename to datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs
index 3258cde17625f..dc374c7e02fe5 100644
--- a/datafusion-examples/examples/memory_pool_execution_plan.rs
+++ b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+//!
 //! This example demonstrates how to implement custom ExecutionPlans that properly
 //! use memory tracking through TrackConsumersPool.
 //!
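As a rough sketch of the reservation pattern this file demonstrates (the pool
size, consumer name, and byte counts are arbitrary; the types are the DataFusion
memory-pool APIs referenced by the imports in the hunks below):

    use std::num::NonZeroUsize;
    use std::sync::Arc;
    use datafusion::error::Result;
    use datafusion::execution::memory_pool::{
        GreedyMemoryPool, MemoryConsumer, MemoryPool, TrackConsumersPool,
    };

    fn reservation_sketch() -> Result<()> {
        // Wrap an inner pool so "resources exhausted" errors report the top
        // consumers by name.
        let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
            GreedyMemoryPool::new(64 * 1024 * 1024), // 64 MB budget
            NonZeroUsize::new(5).unwrap(),           // report top 5 consumers
        ));

        // An operator registers a named consumer, then grows/shrinks its
        // reservation as it buffers data; try_grow fails once the budget is hit.
        let mut reservation =
            MemoryConsumer::new("BufferingExecutionPlan").register(&pool);
        reservation.try_grow(8 * 1024)?; // account for ~8 KB of buffered batches
        reservation.shrink(8 * 1024); // release when spilled or emitted
        Ok(())
    }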
@@ -27,8 +29,9 @@ use arrow::record_batch::RecordBatch; use arrow_schema::SchemaRef; use datafusion::common::record_batch; +use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::common::{exec_datafusion_err, internal_err}; -use datafusion::datasource::{memory::MemTable, DefaultTableSource}; +use datafusion::datasource::{DefaultTableSource, memory::MemTable}; use datafusion::error::Result; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; @@ -36,16 +39,15 @@ use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use datafusion::prelude::*; use futures::stream::{StreamExt, TryStreamExt}; -use std::any::Any; use std::fmt; use std::sync::Arc; -#[tokio::main] -async fn main() -> Result<(), Box> { +/// Shows how to implement memory-aware ExecutionPlan with memory reservation and spilling +pub async fn memory_pool_execution_plan() -> Result<()> { println!("=== DataFusion ExecutionPlan Memory Tracking Example ===\n"); // Set up a runtime with memory tracking @@ -140,6 +142,7 @@ impl ExternalBatchBufferer { } } + #[expect(clippy::needless_pass_by_value)] fn add_batch(&mut self, batch_data: Vec) -> Result<()> { let additional_memory = batch_data.len(); @@ -196,7 +199,7 @@ impl ExternalBatchBufferer { struct BufferingExecutionPlan { schema: SchemaRef, input: Arc, - properties: PlanProperties, + properties: Arc, } impl BufferingExecutionPlan { @@ -222,15 +225,11 @@ impl ExecutionPlan for BufferingExecutionPlan { "BufferingExecutionPlan" } - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.properties } @@ -294,7 +293,19 @@ impl ExecutionPlan for BufferingExecutionPlan { ))) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema)) + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.properties.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) } } diff --git a/datafusion-examples/examples/memory_pool_tracking.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs similarity index 89% rename from datafusion-examples/examples/memory_pool_tracking.rs rename to datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs index d5823b1173ab3..d849a033bc66b 100644 --- a/datafusion-examples/examples/memory_pool_tracking.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates how to use TrackConsumersPool for memory tracking and debugging. //! //! The TrackConsumersPool provides enhanced error messages that show the top memory consumers @@ -24,11 +26,12 @@ //! //! 
* [`automatic_usage_example`]: Shows how to use RuntimeEnvBuilder to automatically enable memory tracking +use datafusion::error::Result; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::prelude::*; -#[tokio::main] -async fn main() -> Result<(), Box> { +/// Demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages +pub async fn mem_pool_tracking() -> Result<()> { println!("=== DataFusion Memory Pool Tracking Example ===\n"); // Example 1: Automatic Usage with RuntimeEnvBuilder @@ -41,7 +44,7 @@ async fn main() -> Result<(), Box> { /// /// This shows the recommended way to use TrackConsumersPool through RuntimeEnvBuilder, /// which automatically creates a TrackConsumersPool with sensible defaults. -async fn automatic_usage_example() -> datafusion::error::Result<()> { +async fn automatic_usage_example() -> Result<()> { println!("Example 1: Automatic Usage with RuntimeEnvBuilder"); println!("------------------------------------------------"); @@ -107,7 +110,8 @@ async fn automatic_usage_example() -> datafusion::error::Result<()> { println!("✓ Expected memory limit error during data processing:"); println!("Error: {e}"); /* Example error message: - Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes + Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: ExternalSorterMerge[3]#112(can spill: false) consumed 10.0 MB, peak 10.0 MB, @@ -115,7 +119,8 @@ async fn automatic_usage_example() -> datafusion::error::Result<()> { ExternalSorter[1]#93(can spill: true) consumed 69.0 KB, peak 69.0 KB, ExternalSorter[13]#155(can spill: true) consumed 67.6 KB, peak 67.6 KB, ExternalSorter[8]#140(can spill: true) consumed 67.2 KB, peak 67.2 KB. - Error: Failed to allocate additional 10.0 MB for ExternalSorterMerge[0] with 0.0 B already allocated for this reservation - 7.1 MB remain available for the total pool + Error: Failed to allocate additional 10.0 MB for ExternalSorterMerge[0] with 0.0 B already allocated + for this reservation - 7.1 MB remain available for the total memory pool */ } } diff --git a/datafusion-examples/examples/tracing.rs b/datafusion-examples/examples/execution_monitoring/tracing.rs similarity index 82% rename from datafusion-examples/examples/tracing.rs rename to datafusion-examples/examples/execution_monitoring/tracing.rs index 334ee0f4e5686..172c1ca83b3bd 100644 --- a/datafusion-examples/examples/tracing.rs +++ b/datafusion-examples/examples/execution_monitoring/tracing.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates the tracing injection feature for the DataFusion runtime. //! Tasks spawned on new threads behave differently depending on whether a tracer is injected. //! The log output clearly distinguishes the two cases. @@ -49,20 +51,21 @@ //! 10:29:40.809 INFO main ThreadId(01) tracing: ***** WITH tracer: Non-main tasks DID inherit the `run_instrumented_query` span ***** //! 
``` -use datafusion::common::runtime::{set_join_set_tracer, JoinSetTracer}; +use std::any::Any; +use std::sync::Arc; + +use datafusion::common::runtime::{JoinSetTracer, set_join_set_tracer}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion::test_util::parquet_test_data; -use futures::future::BoxFuture; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::FutureExt; -use std::any::Any; -use std::sync::Arc; -use tracing::{info, instrument, Instrument, Level, Span}; +use futures::future::BoxFuture; +use tracing::{Instrument, Level, Span, info, instrument}; -#[tokio::main] -async fn main() -> Result<()> { +/// Demonstrates the tracing injection feature for the DataFusion runtime +pub async fn tracing() -> Result<()> { // Initialize tracing subscriber with thread info. tracing_subscriber::fmt() .with_thread_ids(true) @@ -73,7 +76,9 @@ async fn main() -> Result<()> { // Run query WITHOUT tracer injection. info!("***** RUNNING WITHOUT INJECTED TRACER *****"); run_instrumented_query().await?; - info!("***** WITHOUT tracer: `tokio-runtime-worker` tasks did NOT inherit the `run_instrumented_query` span *****"); + info!( + "***** WITHOUT tracer: `tokio-runtime-worker` tasks did NOT inherit the `run_instrumented_query` span *****" + ); // Inject custom tracer so tasks run in the current span. info!("Injecting custom tracer..."); @@ -82,7 +87,9 @@ async fn main() -> Result<()> { // Run query WITH tracer injection. info!("***** RUNNING WITH INJECTED TRACER *****"); run_instrumented_query().await?; - info!("***** WITH tracer: `tokio-runtime-worker` tasks DID inherit the `run_instrumented_query` span *****"); + info!( + "***** WITH tracer: `tokio-runtime-worker` tasks DID inherit the `run_instrumented_query` span *****" + ); Ok(()) } @@ -120,18 +127,27 @@ async fn run_instrumented_query() -> Result<()> { info!("Starting query execution"); let ctx = SessionContext::new(); - let test_data = parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension("alltypes_tiny_pages_plain.parquet"); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); - let table_path = format!("file://{test_data}/"); - info!("Registering table 'alltypes' from {}", table_path); - ctx.register_listing_table("alltypes", &table_path, listing_options, None, None) - .await - .expect("Failed to register table"); + info!("Registering table 'cars' from {}", parquet_temp.path_str()?); + ctx.register_listing_table( + "cars", + parquet_temp.path_str()?, + listing_options, + None, + None, + ) + .await + .expect("Failed to register table"); - let sql = "SELECT COUNT(*), string_col FROM alltypes GROUP BY string_col"; + let sql = "SELECT COUNT(*), car, sum(speed) FROM cars GROUP BY car"; info!(sql, "Executing SQL query"); let result = ctx.sql(sql).await?.collect().await?; info!("Query complete: {} batches returned", result.len()); diff --git a/datafusion-examples/examples/extension_types/main.rs b/datafusion-examples/examples/extension_types/main.rs new file mode 100644 index 0000000000000..97c00fdcb64f8 --- 
/dev/null +++ b/datafusion-examples/examples/extension_types/main.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Extension type usage examples +//! +//! These examples demonstrate the API for creating and using custom extension types. +//! +//! ## Usage +//! ```bash +//! cargo run --example extension_types -- [all|temperature] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `temperature` +//! (file: temperature.rs, desc: Extension type for temperature data.) + +mod temperature; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Temperature, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "extension_types"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Temperature => { + temperature::temperature_example().await?; + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/extension_types/temperature.rs b/datafusion-examples/examples/extension_types/temperature.rs new file mode 100644 index 0000000000000..478cf5ebbf312 --- /dev/null +++ b/datafusion-examples/examples/extension_types/temperature.rs @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrowPrimitiveType, AsArray, Float32Array, Float64Array, PrimitiveArray, + RecordBatch, StringArray, +}; +use arrow::datatypes::{Float32Type, Float64Type}; +use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult}; +use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}; +use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; +use datafusion::dataframe::DataFrame; +use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; +use datafusion::prelude::SessionContext; +use datafusion_common::internal_err; +use datafusion_common::types::DFExtensionType; +use datafusion_expr::registry::{ + ExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry, +}; +use std::collections::HashMap; +use std::fmt::{Display, Write}; +use std::sync::Arc; + +/// This example demonstrates using DataFusion's extension type API to create a custom +/// extension type [`TemperatureExtensionType`] for representing different temperature units. +pub async fn temperature_example() -> Result<()> { + let ctx = create_session_context()?; + register_temperature_table(&ctx).await?; + + // Print the example table with the custom pretty-printer. + ctx.table("example").await?.show().await +} + +/// Creates the DataFusion session context with the custom extension type implementation. +fn create_session_context() -> Result { + let registry = MemoryExtensionTypeRegistry::new_empty(); + + // The registration creates a new instance of the extension type with the deserialized metadata. + let temp_registration = ExtensionTypeRegistration::new_arc( + TemperatureExtensionType::NAME, + |storage_type, metadata| { + Ok(Arc::new(TemperatureExtensionType::try_new( + storage_type, + TemperatureUnit::deserialize(metadata)?, + )?)) + }, + ); + registry.add_extension_type_registration(temp_registration)?; + + let state = SessionStateBuilder::default() + .with_extension_type_registry(Arc::new(registry)) + .build(); + Ok(SessionContext::new_with_state(state)) +} + +/// Registers the example table and returns the data frame. +async fn register_temperature_table(ctx: &SessionContext) -> Result { + let schema = example_schema(); + + let city_names = Arc::new(StringArray::from(vec![ + "Vienna", "Tokyo", "New York", "Sydney", + ])); + + // The temperature readings in different units + let celsius_temps = vec![15.1, 22.5, 18.98, 25.0]; + let fahrenheit_temps = vec![59.18, 72.5, 66.164, 77.0]; + let kelvin_temps = vec![288.25, 295.65, 292.13, 298.15]; + + let batch = RecordBatch::try_new( + schema, + vec![ + city_names, + Arc::new(Float64Array::from(celsius_temps)), + Arc::new(Float64Array::from(fahrenheit_temps)), + Arc::new(Float32Array::from(kelvin_temps)), // Demonstrate use of different storage type + ], + )?; + + ctx.register_batch("example", batch)?; + ctx.table("example").await +} + +/// The schema of the example table. 
+fn example_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("celsius", DataType::Float64, false) + .with_metadata(create_metadata(TemperatureUnit::Celsius)), + Field::new("fahrenheit", DataType::Float64, false) + .with_metadata(create_metadata(TemperatureUnit::Fahrenheit)), + Field::new("kelvin", DataType::Float32, false) + .with_metadata(create_metadata(TemperatureUnit::Kelvin)), + ])) +} + +/// Represents a float that semantically represents a temperature. The temperature can be one of +/// the supported [`TemperatureUnit`]s. +/// +/// The unit is realized as an additional extension type metadata and is stored alongside the +/// extension type name in the Arrow field metadata. This metadata can also be stored within files, +/// allowing DataFusion to read temperature data from, for example, Parquet files. +/// +/// The field metadata for a Celsius temperature field will look like this (serialized as JSON): +/// ```json +/// { +/// "ARROW:extension:name": "custom.temperature", +/// "ARROW:extension:metadata": "celsius" +/// } +/// ``` +/// +/// See [the official Arrow documentation](https://arrow.apache.org/docs/format/Columnar.html#extension-types) +/// for more details on the extension type mechanism. +#[derive(Debug)] +pub struct TemperatureExtensionType { + /// Extension type instances are always for a specific storage type and metadata pairing. + /// Therefore, we store the storage type. + storage_type: DataType, + /// The unit of the temperature. + temperature_unit: TemperatureUnit, +} + +impl TemperatureExtensionType { + /// The name of the extension type. + pub const NAME: &'static str = "custom.temperature"; + + /// Creates a new [`TemperatureExtensionType`]. + pub fn try_new( + storage_type: &DataType, + temperature_unit: TemperatureUnit, + ) -> Result { + match storage_type { + DataType::Float32 | DataType::Float64 => {} + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid data type: {storage_type} for temperature type, expected Float32 or Float64", + ))); + } + } + + let result = Self { + storage_type: storage_type.clone(), + temperature_unit, + }; + Ok(result) + } +} + +/// Represents the unit of a temperature reading. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TemperatureUnit { + Celsius, + Fahrenheit, + Kelvin, +} + +impl TemperatureUnit { + /// Arrow extension type metadata is encoded as a string and stored using the + /// `ARROW:extension:metadata` key. As we only store the name of the unit, a simple string + /// suffices. Extension types can store more complex metadata using serialization formats like + /// JSON. + pub fn serialize(self) -> String { + let result = match self { + TemperatureUnit::Celsius => "celsius", + TemperatureUnit::Fahrenheit => "fahrenheit", + TemperatureUnit::Kelvin => "kelvin", + }; + result.to_owned() + } + + /// Inverse operation of [`TemperatureUnit::serialize`]. This creates the [`TemperatureUnit`] + /// value from the serialized string. 
+ pub fn deserialize(value: Option<&str>) -> std::result::Result { + match value { + Some("celsius") => Ok(TemperatureUnit::Celsius), + Some("fahrenheit") => Ok(TemperatureUnit::Fahrenheit), + Some("kelvin") => Ok(TemperatureUnit::Kelvin), + Some(other) => Err(ArrowError::InvalidArgumentError(format!( + "Invalid metadata for temperature type: {other}" + ))), + None => Err(ArrowError::InvalidArgumentError( + "Temperature type requires metadata (unit)".to_owned(), + )), + } + } +} + +/// This creates a metadata map for the temperature type. Another way of writing the metadata can be +/// implemented using arrow-rs' [`ExtensionType`](arrow_schema::extension::ExtensionType) trait. +fn create_metadata(unit: TemperatureUnit) -> HashMap { + HashMap::from([ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + TemperatureExtensionType::NAME.to_owned(), + ), + (EXTENSION_TYPE_METADATA_KEY.to_owned(), unit.serialize()), + ]) +} + +/// Implementation of [`DFExtensionType`] for [`TemperatureExtensionType`]. +/// +/// This implements the trait for customizing DataFusion. +impl DFExtensionType for TemperatureExtensionType { + fn storage_type(&self) -> DataType { + self.storage_type.clone() + } + + fn serialize_metadata(&self) -> Option { + Some(self.temperature_unit.serialize()) + } + + fn create_array_formatter<'fmt>( + &self, + array: &'fmt dyn Array, + options: &FormatOptions<'fmt>, + ) -> Result>> { + match self.storage_type { + DataType::Float32 => { + let display_index = TemperatureDisplayIndex { + array: array.as_primitive::(), + null_str: options.null(), + unit: self.temperature_unit, + }; + Ok(Some(ArrayFormatter::new( + Box::new(display_index), + options.safe(), + ))) + } + DataType::Float64 => { + let display_index = TemperatureDisplayIndex { + array: array.as_primitive::(), + null_str: options.null(), + unit: self.temperature_unit, + }; + Ok(Some(ArrayFormatter::new( + Box::new(display_index), + options.safe(), + ))) + } + _ => internal_err!("Wrong array type for Temperature"), + } + } +} + +/// Pretty printer for temperatures. +#[derive(Debug)] +struct TemperatureDisplayIndex<'a, TNative: ArrowPrimitiveType> { + array: &'a PrimitiveArray, + null_str: &'a str, + unit: TemperatureUnit, +} + +/// Implements the custom display logic. 
+impl> DisplayIndex + for TemperatureDisplayIndex<'_, TNative> +{ + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.array.is_null(idx) { + write!(f, "{}", self.null_str)?; + return Ok(()); + } + + let value = self.array.value(idx); + let suffix = match self.unit { + TemperatureUnit::Celsius => "°C", + TemperatureUnit::Fahrenheit => "°F", + TemperatureUnit::Kelvin => "K", + }; + + write!(f, "{value:.2} {suffix}")?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use insta::assert_snapshot; + + #[tokio::test] + async fn test_print_example_table() -> Result<()> { + let ctx = create_session_context()?; + let table = register_temperature_table(&ctx).await?; + + assert_snapshot!( + table.to_string().await?, + @r" + +----------+----------+------------+----------+ + | city | celsius | fahrenheit | kelvin | + +----------+----------+------------+----------+ + | Vienna | 15.10 °C | 59.18 °F | 288.25 K | + | Tokyo | 22.50 °C | 72.50 °F | 295.65 K | + | New York | 18.98 °C | 66.16 °F | 292.13 K | + | Sydney | 25.00 °C | 77.00 °F | 298.15 K | + +----------+----------+------------+----------+ + " + ); + + Ok(()) + } +} diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe_to_s3.rs similarity index 87% rename from datafusion-examples/examples/external_dependency/dataframe-to-s3.rs rename to datafusion-examples/examples/external_dependency/dataframe_to_s3.rs index e75ba5dd5328a..fdb8a3c9c051a 100644 --- a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs +++ b/datafusion-examples/examples/external_dependency/dataframe_to_s3.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::env; use std::sync::Arc; use datafusion::dataframe::DataFrameWriteOptions; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; @@ -28,14 +30,18 @@ use datafusion::prelude::*; use object_store::aws::AmazonS3Builder; use url::Url; -/// This example demonstrates querying data from AmazonS3 and writing -/// the result of a query back to AmazonS3 -#[tokio::main] -async fn main() -> Result<()> { +/// This example demonstrates querying data from Amazon S3 and writing +/// the result of a query back to Amazon S3. 
+///
+/// The following environment variables must be defined:
+///
+/// - AWS_ACCESS_KEY_ID
+/// - AWS_SECRET_ACCESS_KEY
+pub async fn dataframe_to_s3() -> Result<()> {
     // create local execution context
     let ctx = SessionContext::new();
 
-    //enter region and bucket to which your credentials have GET and PUT access
+    // enter region and bucket to which your credentials have GET and PUT access
     let region = "";
     let bucket_name = "";
 
@@ -66,13 +72,13 @@
         .write_parquet(&out_path, DataFrameWriteOptions::new(), None)
         .await?;
 
-    //write as JSON to s3
+    // write as JSON to s3
     let json_out = format!("s3://{bucket_name}/json_out");
     df.clone()
         .write_json(&json_out, DataFrameWriteOptions::new(), None)
         .await?;
 
-    //write as csv to s3
+    // write as csv to s3
     let csv_out = format!("s3://{bucket_name}/csv_out");
     df.write_csv(&csv_out, DataFrameWriteOptions::new(), None)
         .await?;
diff --git a/datafusion-examples/examples/external_dependency/main.rs b/datafusion-examples/examples/external_dependency/main.rs
new file mode 100644
index 0000000000000..447e7d38bdd5b
--- /dev/null
+++ b/datafusion-examples/examples/external_dependency/main.rs
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Amazon S3 examples
+//!
+//! These examples demonstrate how to work with data from Amazon S3.
+//!
+//! ## Usage
+//! ```bash
+//! cargo run --example external_dependency -- [all|dataframe_to_s3|query_aws_s3]
+//! ```
+//!
+//! Each subcommand runs a corresponding example:
+//! - `all` — run all examples included in this module
+//!
+//! - `dataframe_to_s3`
+//!   (file: dataframe_to_s3.rs, desc: Query DataFrames and write results to S3)
+//!
+//! - `query_aws_s3`
+//!
(file: query_aws_s3.rs, desc: Query S3-backed data using object_store) + +mod dataframe_to_s3; +mod query_aws_s3; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + DataframeToS3, + QueryAwsS3, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "external_dependency"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::DataframeToS3 => dataframe_to_s3::dataframe_to_s3().await?, + ExampleKind::QueryAwsS3 => query_aws_s3::query_aws_s3().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/external_dependency/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query_aws_s3.rs similarity index 90% rename from datafusion-examples/examples/external_dependency/query-aws-s3.rs rename to datafusion-examples/examples/external_dependency/query_aws_s3.rs index da2d7e4879f99..63507bb3eed11 100644 --- a/datafusion-examples/examples/external_dependency/query-aws-s3.rs +++ b/datafusion-examples/examples/external_dependency/query_aws_s3.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; use object_store::aws::AmazonS3Builder; @@ -22,15 +24,13 @@ use std::env; use std::sync::Arc; use url::Url; -/// This example demonstrates querying data in an S3 bucket. +/// This example demonstrates querying data in a public S3 bucket +/// (the NYC TLC open dataset: `s3://nyc-tlc`). 
/// /// The following environment variables must be defined: -/// -/// - AWS_ACCESS_KEY_ID -/// - AWS_SECRET_ACCESS_KEY -/// -#[tokio::main] -async fn main() -> Result<()> { +/// - `AWS_ACCESS_KEY_ID` +/// - `AWS_SECRET_ACCESS_KEY` +pub async fn query_aws_s3() -> Result<()> { let ctx = SessionContext::new(); // the region must be set to the region where the bucket exists until the following diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml index e9c0c5b43d682..3cfa6dcf90f18 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml @@ -22,12 +22,14 @@ edition = { workspace = true } publish = false [dependencies] -abi_stable = "0.11.3" arrow = { workspace = true } datafusion = { workspace = true } datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } +[lints] +workspace = true + [lib] name = "ffi_example_table_provider" crate-type = ["cdylib", 'rlib'] diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs index a83f15926f054..7894e97f3796d 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs @@ -17,12 +17,12 @@ use std::sync::Arc; -use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait}; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{common::record_batch, datasource::MemTable}; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; -use ffi_module_interface::{TableProviderModule, TableProviderModuleRef}; +use ffi_module_interface::TableProviderModule; fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { let end_value = start_value + num_values as i32; @@ -34,7 +34,9 @@ fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to it's FFI counterpart. -extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { +extern "C" fn construct_simple_table_provider( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProvider { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -50,14 +52,13 @@ extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new_with_ffi_codec(Arc::new(table_provider), true, None, codec) } -#[export_root_module] +#[unsafe(no_mangle)] /// This defines the entry point for using the module. 
-pub fn get_simple_memory_table() -> TableProviderModuleRef { +pub extern "C" fn ffi_example_get_module() -> TableProviderModule { TableProviderModule { create_table: construct_simple_table_provider, } - .leak_into_prefix() } diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml index 612a219324763..0244cb2a5ed15 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml @@ -18,9 +18,11 @@ [package] name = "ffi_module_interface" version = "0.1.0" -edition = "2021" +edition = "2024" publish = false +[lints] +workspace = true + [dependencies] -abi_stable = "0.11.3" datafusion-ffi = { workspace = true } diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs index 88690e9297135..54a59c9e5d073 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs @@ -15,35 +15,17 @@ // specific language governing permissions and limitations // under the License. -use abi_stable::{ - declare_root_module_statics, - library::{LibraryError, RootModule}, - package_version_strings, - sabi_types::VersionStrings, - StableAbi, -}; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; -#[repr(C)] -#[derive(StableAbi)] -#[sabi(kind(Prefix(prefix_ref = TableProviderModuleRef)))] /// This struct defines the module interfaces. It is to be shared by /// both the module loading program and library that implements the /// module. It is possible to move this definition into the loading /// program and reference it in the modules, but this example shows /// how a user may wish to separate these concerns. +#[repr(C)] pub struct TableProviderModule { /// Constructs the table provider - pub create_table: extern "C" fn() -> FFI_TableProvider, -} - -impl RootModule for TableProviderModuleRef { - declare_root_module_statics! 
{TableProviderModuleRef} - const BASE_NAME: &'static str = "ffi_example_table_provider"; - const NAME: &'static str = "ffi_example_table_provider"; - const VERSION_STRINGS: VersionStrings = package_version_strings!(); - - fn initialization(self) -> Result<Self, LibraryError> { - Ok(self) - } + pub create_table: + extern "C" fn(codec: FFI_LogicalExtensionCodec) -> FFI_TableProvider, } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml index 028a366aab1c0..e7b2dd19009b5 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -18,12 +18,15 @@ [package] name = "ffi_module_loader" version = "0.1.0" -edition = "2021" +edition = "2024" publish = false +[lints] +workspace = true + [dependencies] -abi_stable = "0.11.3" datafusion = { workspace = true } datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } +libloading = "0.9" tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs index 6e376ca866e8f..0657c4a08fa86 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs +++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs @@ -18,44 +18,69 @@ use std::sync::Arc; use datafusion::{ + datasource::TableProvider, error::{DataFusionError, Result}, + execution::TaskContextProvider, prelude::SessionContext, }; - -use abi_stable::library::{development_utils::compute_library_path, RootModule}; -use datafusion_ffi::table_provider::ForeignTableProvider; -use ffi_module_interface::TableProviderModuleRef; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use ffi_module_interface::TableProviderModule; #[tokio::main] async fn main() -> Result<()> { // Find the location of the library. This is specific to the build environment, // so you will need to change the approach here based on your use case. - let target: &std::path::Path = "../../../../target/".as_ref(); - let library_path = compute_library_path::<TableProviderModuleRef>(target) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + let lib_prefix = if cfg!(target_os = "windows") { "" } else { "lib" }; + let lib_ext = if cfg!(target_os = "macos") { "dylib" } else if cfg!(target_os = "windows") { "dll" } else { "so" }; + + let build_type = if cfg!(debug_assertions) { "debug" } else { "release" }; + + let library_path = format!( "../../../../target/{build_type}/{lib_prefix}ffi_example_table_provider.{lib_ext}" ); + + // Load the library using libloading + let lib = unsafe { libloading::Library::new(&library_path) .map_err(|e| DataFusionError::External(Box::new(e)))? }; - // Load the module - let table_provider_module = - TableProviderModuleRef::load_from_directory(&library_path) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + let get_module: libloading::Symbol<extern "C" fn() -> TableProviderModule> = unsafe { lib.get(b"ffi_example_get_module") .map_err(|e| DataFusionError::External(Box::new(e)))? }; + + let table_provider_module = get_module(); + + let ctx = Arc::new(SessionContext::new()); + let codec = FFI_LogicalExtensionCodec::new_default( &(Arc::clone(&ctx) as Arc<dyn TaskContextProvider>), ); // By calling the code below, the table provided will be created within // the module's code.
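As an aside, the `libloading` pattern the loader now relies on, reduced to a standalone sketch; the library path and `add` symbol are hypothetical, not part of this PR:

```rust
// Hypothetical: assumes some cdylib at ./libexample.so exporting
// `#[unsafe(no_mangle)] pub extern "C" fn add(a: i32, b: i32) -> i32`
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Loading a library and resolving symbols is inherently unsafe:
    // the caller must guarantee the symbol's real signature matches
    let lib = unsafe { libloading::Library::new("./libexample.so")? };
    let add: libloading::Symbol<unsafe extern "C" fn(i32, i32) -> i32> =
        unsafe { lib.get(b"add")? };
    println!("2 + 3 = {}", unsafe { add(2, 3) });
    Ok(())
}
```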
- let ffi_table_provider = - table_provider_module - .create_table() - .ok_or(DataFusionError::NotImplemented( - "External table provider failed to implement create_table".to_string(), - ))?(); + let ffi_table_provider = (table_provider_module.create_table)(codec); // In order to access the table provider within this executable, we need to - // turn it into a `ForeignTableProvider`. - let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + // turn it into a `TableProvider`. + let foreign_table_provider: Arc<dyn TableProvider> = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; df.show().await?; diff --git a/datafusion-examples/examples/flight/flight_client.rs b/datafusion-examples/examples/flight/client.rs similarity index 83% rename from datafusion-examples/examples/flight/flight_client.rs rename to datafusion-examples/examples/flight/client.rs index ff4b5903ad884..8f6856a4e4849 100644 --- a/datafusion-examples/examples/flight/flight_client.rs +++ b/datafusion-examples/examples/flight/client.rs @@ -15,24 +15,30 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::collections::HashMap; use std::sync::Arc; -use tonic::transport::Endpoint; - -use datafusion::arrow::datatypes::Schema; use arrow_flight::flight_descriptor; use arrow_flight::flight_service_client::FlightServiceClient; use arrow_flight::utils::flight_data_to_arrow_batch; use arrow_flight::{FlightDescriptor, Ticket}; +use datafusion::arrow::datatypes::Schema; use datafusion::arrow::util::pretty; +use datafusion::prelude::SessionContext; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use tonic::transport::Endpoint; /// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for /// Parquet files and executing SQL queries against them on a remote server. /// This example is run along-side the example `flight_server`.
-#[tokio::main] -async fn main() -> Result<(), Box<dyn std::error::Error>> { - let testdata = datafusion::test_util::parquet_test_data(); +pub async fn client() -> Result<(), Box<dyn std::error::Error>> { + let ctx = SessionContext::new(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Create Flight client let endpoint = Endpoint::new("http://localhost:50051")?; @@ -43,7 +49,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let request = tonic::Request::new(FlightDescriptor { r#type: flight_descriptor::DescriptorType::Path as i32, cmd: Default::default(), - path: vec![format!("{testdata}/alltypes_plain.parquet")], + path: vec![format!("{}", parquet_temp.path_str()?)], }); let schema_result = client.get_schema(request).await?.into_inner(); @@ -52,7 +58,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { // Call do_get to execute a SQL query and receive results let request = tonic::Request::new(Ticket { - ticket: "SELECT id FROM alltypes_plain".into(), + ticket: "SELECT car FROM cars".into(), }); let mut stream = client.do_get(request).await?.into_inner(); diff --git a/datafusion-examples/examples/flight/main.rs b/datafusion-examples/examples/flight/main.rs new file mode 100644 index 0000000000000..426e806486f70 --- /dev/null +++ b/datafusion-examples/examples/flight/main.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Arrow Flight Examples +//! +//! These examples demonstrate Arrow Flight usage. +//! +//! ## Usage +//! ```bash +//! cargo run --example flight -- [all|client|server|sql_server] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! Note: The Flight server must be started in a separate process +//! before running the `client` example. Therefore, running `all` will +//! not produce a full server+client workflow automatically. +//! +//! - `client` +//! (file: client.rs, desc: Execute SQL queries via Arrow Flight protocol) +//! +//! - `server` +//! (file: server.rs, desc: Run DataFusion server accepting FlightSQL/JDBC queries) +//! +//! - `sql_server` +//! (file: sql_server.rs, desc: Standalone SQL server for JDBC clients) + +mod client; +mod server; +mod sql_server; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +/// The `all` option cannot run all examples end-to-end because the +/// `server` example must run in a separate process before the `client` +/// example can connect. +/// Therefore, `all` only iterates over individually runnable examples.
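The `strum`-based dispatch these new `main.rs` wrappers share, distilled into a standalone sketch; the `HelloWorld` variant is hypothetical:

```rust
use strum::{IntoEnumIterator, VariantNames};
use strum_macros::{Display, EnumIter, EnumString, VariantNames};

#[derive(EnumIter, EnumString, Display, VariantNames)]
#[strum(serialize_all = "snake_case")]
enum Example {
    All,
    HelloWorld, // hypothetical subcommand
}

fn main() {
    // VARIANTS renders the usage string; FromStr maps argv[1] to a variant
    println!("usage: [{}]", Example::VARIANTS.join("|"));
    let selected: Example = "hello_world".parse().expect("known variant");
    println!("selected: {selected}");
    // `all` iterates every variant except itself
    for example in Example::iter().filter(|e| !matches!(e, Example::All)) {
        println!("would run: {example}");
    }
}
```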
+#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Client, + Server, + SqlServer, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "flight"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<(), Box<dyn std::error::Error>> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Client => client::client().await?, + ExampleKind::Server => server::server().await?, + ExampleKind::SqlServer => sql_server::sql_server().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/server.rs similarity index 88% rename from datafusion-examples/examples/flight/flight_server.rs rename to datafusion-examples/examples/flight/server.rs index 22265e415fbdb..b73c81dd7d2c3 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/server.rs @@ -15,25 +15,26 @@ // specific language governing permissions and limitations // under the License. -use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; +//! See `main.rs` for how to run it. + use std::sync::Arc; +use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; +use arrow_flight::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, +}; use arrow_flight::{PollInfo, SchemaAsIpc}; use datafusion::arrow::error::ArrowError; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; +use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::stream::BoxStream; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; -use datafusion::prelude::*; - -use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, -}; - #[derive(Clone)] pub struct FlightServiceImpl {} @@ -83,16 +84,21 @@ impl FlightService for FlightServiceImpl { // create local execution context let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()) + .await + .map_err(|e| { + Status::internal(format!("Error writing csv to parquet: {e}")) + })?; + let parquet_path = parquet_temp.path_str().map_err(|e| { + Status::internal(format!("Error getting parquet path: {e}")) + })?; // register parquet file with
the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await - .map_err(to_tonic_err)?; + ctx.register_parquet("cars", parquet_path, ParquetReadOptions::default()) + .await + .map_err(to_tonic_err)?; // create the DataFrame let df = ctx.sql(sql).await.map_err(to_tonic_err)?; @@ -187,6 +193,7 @@ impl FlightService for FlightServiceImpl { } } +#[expect(clippy::needless_pass_by_value)] fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { Status::internal(format!("{e:?}")) } @@ -194,8 +201,7 @@ fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { /// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for /// Parquet files and executing SQL queries against them on a remote server. /// This example is run along-side the example `flight_client`. -#[tokio::main] -async fn main() -> Result<(), Box<dyn std::error::Error>> { +pub async fn server() -> Result<(), Box<dyn std::error::Error>> { let addr = "0.0.0.0:50051".parse()?; let service = FlightServiceImpl {}; diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/sql_server.rs similarity index 94% rename from datafusion-examples/examples/flight/flight_sql_server.rs rename to datafusion-examples/examples/flight/sql_server.rs index c35debec7d712..e55aaa7250ea7 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/sql_server.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +use std::pin::Pin; +use std::sync::Arc; + use arrow::array::{ArrayRef, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::ipc::writer::IpcWriteOptions; @@ -36,12 +41,11 @@ use arrow_flight::{ use dashmap::DashMap; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext}; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::{Stream, StreamExt, TryStreamExt}; use log::info; use mimalloc::MiMalloc; use prost::Message; -use std::pin::Pin; -use std::sync::Arc; use tonic::metadata::MetadataValue; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; @@ -68,9 +72,7 @@ macro_rules!
status { /// /// Based heavily on Ballista's implementation: https://github.com/apache/datafusion-ballista/blob/main/ballista/scheduler/src/flight_sql.rs /// and the example in arrow-rs: https://github.com/apache/arrow-rs/blob/master/arrow-flight/examples/flight_sql_server.rs -/// -#[tokio::main] -async fn main() -> Result<(), Box<dyn std::error::Error>> { +pub async fn sql_server() -> Result<(), Box<dyn std::error::Error>> { env_logger::init(); let addr = "0.0.0.0:50051".parse()?; let service = FlightSqlServiceImpl { @@ -100,22 +102,24 @@ impl FlightSqlServiceImpl { .with_information_schema(true); let ctx = Arc::new(SessionContext::new_with_config(session_config)); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()) + .await + .map_err(|e| status!("Error writing csv to parquet", e))?; + let parquet_path = parquet_temp + .path_str() + .map_err(|e| status!("Error getting parquet path", e))?; // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await - .map_err(|e| status!("Error registering table", e))?; + ctx.register_parquet("cars", parquet_path, ParquetReadOptions::default()) + .await + .map_err(|e| status!("Error registering table", e))?; self.contexts.insert(uuid.clone(), ctx); Ok(uuid) } - #[allow(clippy::result_large_err)] fn get_ctx<T>(&self, req: &Request<T>) -> Result<Arc<SessionContext>, Status> { // get the token from the authorization header on Request let auth = req @@ -141,7 +145,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_plan(&self, handle: &str) -> Result<LogicalPlan, Status> { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -150,7 +153,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result<Vec<RecordBatch>, Status> { if let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -198,13 +200,11 @@ impl FlightSqlServiceImpl { .unwrap() } - #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } - #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) @@ -416,7 +416,9 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result<(), Status> { let handle = std::str::from_utf8(&handle.prepared_statement_handle); if let Ok(handle) = handle { - info!("do_action_close_prepared_statement: removing plan and results for {handle}"); + info!( + "do_action_close_prepared_statement: removing plan and results for {handle}" + ); let _ = self.remove_plan(handle); let _ = self.remove_result(handle); } diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/proto/composed_extension_codec.rs similarity index 84% rename from datafusion-examples/examples/composed_extension_codec.rs rename to datafusion-examples/examples/proto/composed_extension_codec.rs index 57f2c370413aa..ae9503dd87b19 100644 --- a/datafusion-examples/examples/composed_extension_codec.rs +++ b/datafusion-examples/examples/proto/composed_extension_codec.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! +//! 
This example demonstrates how to compose multiple PhysicalExtensionCodecs //! //! This can be helpful when an Execution plan tree has different nodes from different crates @@ -30,12 +32,12 @@ //! DeltaScan //! ``` -use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use datafusion::common::internal_err; use datafusion::common::Result; +use datafusion::common::Result; +use datafusion::common::internal_err; +use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::execution::TaskContext; use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; use datafusion::prelude::SessionContext; @@ -44,8 +46,8 @@ use datafusion_proto::physical_plan::{ }; use datafusion_proto::protobuf; -#[tokio::main] -async fn main() { +/// Example of using multiple extension codecs for serialization / deserialization +pub async fn composed_extension_codec() -> Result<()> { // build execution plan that has both types of nodes // // Note each node requires a different `PhysicalExtensionCodec` to decode @@ -66,16 +68,16 @@ async fn main() { protobuf::PhysicalPlanNode::try_from_physical_plan( exec_plan.clone(), &composed_codec, - ) - .expect("to proto"); + )?; // deserialize proto back to execution plan - let result_exec_plan: Arc<dyn ExecutionPlan> = proto - .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) - .expect("from proto"); + let result_exec_plan: Arc<dyn ExecutionPlan> = + proto.try_into_physical_plan(&ctx.task_ctx(), &composed_codec)?; // assert that the original and deserialized execution plans are equal assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + + Ok(()) } /// This example has two types of nodes: `ParentExec` and `ChildExec` which can only @@ -100,11 +102,7 @@ impl ExecutionPlan for ParentExec { "ParentExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &Arc<datafusion::physical_plan::PlanProperties> { unreachable!() } @@ -126,6 +124,15 @@ impl ExecutionPlan for ParentExec { ) -> Result { unreachable!() } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result<TreeNodeRecursion>, + ) -> Result<TreeNodeRecursion> { + Ok(TreeNodeRecursion::Continue) + } } /// A PhysicalExtensionCodec that can serialize and deserialize ParentExec @@ -149,7 +156,7 @@ impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { } fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> { - if node.as_any().downcast_ref::<ParentExec>().is_some() { + if node.is::<ParentExec>() { buf.extend_from_slice("ParentExec".as_bytes()); Ok(()) } else { @@ -176,11 +183,7 @@ impl ExecutionPlan for ChildExec { "ChildExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &Arc<datafusion::physical_plan::PlanProperties> { unreachable!() } @@ -202,6 +205,15 @@ impl ExecutionPlan for ChildExec { ) -> Result { unreachable!() } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result<TreeNodeRecursion>, + ) -> Result<TreeNodeRecursion> { + Ok(TreeNodeRecursion::Continue) + } } /// A PhysicalExtensionCodec that can serialize and deserialize ChildExec @@ -223,7 +235,7 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { } fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> { - if node.as_any().downcast_ref::<ChildExec>().is_some() { + if node.is::<ChildExec>() { buf.extend_from_slice("ChildExec".as_bytes()); Ok(()) } else { diff --git a/datafusion-examples/examples/proto/expression_deduplication.rs b/datafusion-examples/examples/proto/expression_deduplication.rs new file mode 100644 index 0000000000000..26d246b2efca8 --- 
/dev/null +++ b/datafusion-examples/examples/proto/expression_deduplication.rs @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the +//! `PhysicalProtoConverterExtension` trait's interception methods to +//! implement expression deduplication during deserialization. +//! +//! This pattern is inspired by PR #18192, which introduces expression caching +//! to reduce memory usage when deserializing plans with duplicate expressions. +//! +//! The key insight is that identical expressions serialize to identical protobuf bytes. +//! By caching deserialized expressions keyed by their protobuf bytes, we can: +//! 1. Return the same Arc for duplicate expressions +//! 2. Reduce memory allocation during deserialization +//! 3. Enable downstream optimizations that rely on Arc pointer equality +//! +//! This demonstrates the decorator pattern enabled by +//! `PhysicalProtoConverterExtension`, where physical-expression +//! serialization and deserialization route through converter hooks. + +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::Operator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::{BinaryExpr, col}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::prelude::SessionContext; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, PhysicalPlanDecodeContext, + PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +/// Example showing how to implement expression deduplication using the codec decorator pattern. +/// +/// This demonstrates: +/// 1. Creating a CachingCodec that caches expressions by their protobuf bytes +/// 2. Intercepting deserialization to return cached Arcs for duplicate expressions +/// 3. Verifying that duplicate expressions share the same Arc after deserialization +/// +/// Deduplication here is keyed by the protobuf bytes representing the expression; +/// in practice it could instead be keyed on e.g. the pointer address of the +/// serialized expression in memory, but bytes are simpler to demonstrate.
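The caching idea in miniature, independent of DataFusion types; a sketch using `String` payloads rather than physical expressions:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Map serialized bytes -> shared value; identical keys yield the same Arc
struct ByteKeyedCache {
    entries: HashMap<Vec<u8>, Arc<String>>,
}

impl ByteKeyedCache {
    fn get_or_insert(&mut self, key: &[u8], make: impl FnOnce() -> String) -> Arc<String> {
        if let Some(hit) = self.entries.get(key) {
            return Arc::clone(hit); // cache hit: no new allocation
        }
        let value = Arc::new(make());
        self.entries.insert(key.to_vec(), Arc::clone(&value));
        value
    }
}

fn main() {
    let mut cache = ByteKeyedCache { entries: HashMap::new() };
    let first = cache.get_or_insert(b"col(a)", || "col(a)".to_string());
    let second = cache.get_or_insert(b"col(a)", || "col(a)".to_string());
    // pointer equality proves the duplicate was deduplicated
    assert!(Arc::ptr_eq(&first, &second));
}
```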
+/// +/// In this case our expression is trivial and just for demonstration purposes. +/// In real scenarios, expressions can be much more complex, e.g. a large InList +/// expression could be megabytes in size, so deduplication can save significant memory +/// in addition to more correctly representing the original plan structure. +pub async fn expression_deduplication() -> Result<()> { + println!("=== Expression Deduplication Example ===\n"); + + // Create a schema for our test expressions + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); + + // Step 1: Create expressions with duplicates + println!("Step 1: Creating expressions with duplicates..."); + + // Create expression: col("a") + let a = col("a", &schema)?; + + // Create a clone to show duplicates + let a_clone = Arc::clone(&a); + + // Combine: a OR a_clone + let combined_expr = + Arc::new(BinaryExpr::new(a, Operator::Or, a_clone)) as Arc<dyn PhysicalExpr>; + println!(" Created expression: a OR a with duplicates"); + println!(" Note: a appears twice in the expression tree\n"); + // Step 2: Create a filter plan with this expression + println!("Step 2: Creating physical plan with the expression..."); + + let input = Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))); + let filter_plan: Arc<dyn ExecutionPlan> = + Arc::new(FilterExec::try_new(combined_expr, input)?); + + println!(" Created FilterExec with duplicate sub-expressions\n"); + + // Step 3: Serialize with the caching codec + println!("Step 3: Serializing plan..."); + + let extension_codec = DefaultPhysicalExtensionCodec {}; + let caching_converter = CachingCodec::new(); + let proto = + caching_converter.execution_plan_to_proto(&filter_plan, &extension_codec)?; + + // Serialize to bytes + let mut bytes = Vec::new(); + proto.encode(&mut bytes).unwrap(); + println!(" Serialized plan to {} bytes\n", bytes.len()); + + // Step 4: Deserialize with the caching codec + println!("Step 4: Deserializing plan with CachingCodec..."); + + let ctx = SessionContext::new(); + let deserialized_plan = proto.try_into_physical_plan_with_converter( + &ctx.task_ctx(), + &extension_codec, + &caching_converter, + )?; + + // Step 5: check that we deduplicated expressions + println!("Step 5: Checking for deduplicated expressions..."); + let Some(filter_exec) = deserialized_plan.downcast_ref::<FilterExec>() else { + panic!("Deserialized plan is not a FilterExec"); + }; + let predicate = Arc::clone(filter_exec.predicate()); + let binary_expr = predicate + .downcast_ref::<BinaryExpr>() + .expect("Predicate is not a BinaryExpr"); + let left = &binary_expr.left(); + let right = &binary_expr.right(); + // Check if left and right point to the same Arc + let deduplicated = Arc::ptr_eq(left, right); + if deduplicated { + println!(" Success: Duplicate expressions were deduplicated!"); + println!( + " Cache Stats: hits={}, misses={}", + caching_converter.stats.read().unwrap().cache_hits, + caching_converter.stats.read().unwrap().cache_misses, + ); + } else { + println!(" Failure: Duplicate expressions were NOT deduplicated."); + } + + Ok(()) +} + +// ============================================================================ +// CachingCodec - Implements expression deduplication +// ============================================================================ + +/// Statistics for cache performance monitoring +#[derive(Debug, Default)] +struct CacheStats { + cache_hits: usize, + cache_misses: usize, +} + +/// A codec that caches deserialized expressions to enable deduplication.
+/// +/// When deserializing, if we've already seen the same protobuf bytes, +/// we return the cached Arc instead of creating a new allocation. +#[derive(Debug, Default)] +struct CachingCodec { + /// Cache mapping protobuf bytes -> deserialized expression + expr_cache: RwLock<HashMap<Vec<u8>, Arc<dyn PhysicalExpr>>>, + /// Statistics for demonstration + stats: RwLock<CacheStats>, +} + +impl CachingCodec { + fn new() -> Self { + Self::default() + } +} + +impl PhysicalExtensionCodec for CachingCodec { + // Required: decode custom extension nodes + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc<dyn ExecutionPlan>], + _ctx: &TaskContext, + ) -> Result<Arc<dyn ExecutionPlan>> { + datafusion::common::not_impl_err!("No custom extension nodes") + } + + // Required: encode custom execution plans + fn try_encode( + &self, + _node: Arc<dyn ExecutionPlan>, + _buf: &mut Vec<u8>, + ) -> Result<()> { + datafusion::common::not_impl_err!("No custom extension nodes") + } +} + +impl PhysicalProtoConverterExtension for CachingCodec { + fn proto_to_execution_plan( + &self, + proto: &PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result<Arc<dyn ExecutionPlan>> { + self.default_proto_to_execution_plan(proto, ctx) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc<dyn ExecutionPlan>, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result<PhysicalPlanNode> { + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // CACHING IMPLEMENTATION: Intercept expression deserialization + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result<Arc<dyn PhysicalExpr>> { + // Create cache key from protobuf bytes + let mut key = Vec::new(); + proto.encode(&mut key).map_err(|e| { + datafusion::error::DataFusionError::Internal(format!( + "Failed to encode proto for cache key: {e}" + )) + })?; + + // Check cache first + { + let cache = self.expr_cache.read().unwrap(); + if let Some(cached) = cache.get(&key) { + // Cache hit! Update stats and return cached Arc + let mut stats = self.stats.write().unwrap(); + stats.cache_hits += 1; + return Ok(Arc::clone(cached)); + } + } + + // Cache miss - deserialize and store + let expr = parse_physical_expr_with_converter(proto, input_schema, ctx, self)?; + + // Store in cache + { + let mut cache = self.expr_cache.write().unwrap(); + cache.insert(key, Arc::clone(&expr)); + let mut stats = self.stats.write().unwrap(); + stats.cache_misses += 1; + } + + Ok(expr) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc<dyn PhysicalExpr>, + codec: &dyn PhysicalExtensionCodec, + ) -> Result<PhysicalExprNode> { + serialize_physical_expr_with_converter(expr, codec, self) + } +} diff --git a/datafusion-examples/examples/proto/main.rs b/datafusion-examples/examples/proto/main.rs new file mode 100644 index 0000000000000..3f525b5d46afa --- /dev/null +++ b/datafusion-examples/examples/proto/main.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Examples demonstrating DataFusion's plan serialization via the `datafusion-proto` crate +//! +//! These examples show how to use multiple extension codecs for serialization / deserialization. +//! +//! ## Usage +//! ```bash +//! cargo run --example proto -- [all|composed_extension_codec|expression_deduplication] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `composed_extension_codec` +//! (file: composed_extension_codec.rs, desc: Use multiple extension codecs for serialization/deserialization) +//! +//! - `expression_deduplication` +//! (file: expression_deduplication.rs, desc: Example of expression caching/deduplication using the codec decorator pattern) + +mod composed_extension_codec; +mod expression_deduplication; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + ComposedExtensionCodec, + ExpressionDeduplication, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "proto"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::ComposedExtensionCodec => { + composed_extension_codec::composed_extension_codec().await? + } + ExampleKind::ExpressionDeduplication => { + expression_deduplication::expression_deduplication().await? + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/query_planning/analyzer_rule.rs similarity index 97% rename from datafusion-examples/examples/analyzer_rule.rs rename to datafusion-examples/examples/query_planning/analyzer_rule.rs index cb81cd167a88b..a86f5cdd2a5e3 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/query_planning/analyzer_rule.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use datafusion::common::Result; use datafusion::common::config::ConfigOptions; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::Result; -use datafusion::logical_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder, col, lit}; use datafusion::optimizer::analyzer::AnalyzerRule; use datafusion::prelude::SessionContext; use std::sync::{Arc, Mutex}; @@ -35,8 +37,7 @@ use std::sync::{Arc, Mutex}; /// level access control scheme by introducing a filter to the query.
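A sketch of that row-level-filter idea; the `tenant_id` column and literal are illustrative, not part of the example:

```rust
use datafusion::common::Result;
use datafusion::common::config::ConfigOptions;
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
use datafusion::optimizer::analyzer::AnalyzerRule;

// Adds `tenant_id = 42` on top of whatever plan the user wrote;
// assumes the column exists in the plan's schema
#[derive(Debug)]
struct RowLevelFilter;

impl AnalyzerRule for RowLevelFilter {
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(plan)
            .filter(col("tenant_id").eq(lit(42)))?
            .build()
    }

    fn name(&self) -> &str {
        "row_level_filter"
    }
}
```

Registered via `SessionContext::add_analyzer_rule`, such a rule runs before the optimizer, so the injected predicate participates in later optimizations.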
/// /// See [optimizer_rule.rs] for an example of a optimizer rule -#[tokio::main] -pub async fn main() -> Result<()> { +pub async fn analyzer_rule() -> Result<()> { // AnalyzerRules run before OptimizerRules. // // DataFusion includes several built in AnalyzerRules for tasks such as type diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/query_planning/expr_api.rs similarity index 95% rename from datafusion-examples/examples/expr_api.rs rename to datafusion-examples/examples/query_planning/expr_api.rs index 56f960870e58a..c087019c687c5 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/query_planning/expr_api.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{BooleanArray, Int32Array, Int8Array}; +use arrow::array::{BooleanArray, Int8Array, Int32Array}; use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; @@ -35,7 +37,7 @@ use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; use datafusion::optimizer::analyzer::type_coercion::TypeCoercionRewriter; use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; +use datafusion::physical_expr::{AnalysisContext, ExprBoundaries, analyze}; use datafusion::prelude::*; /// This example demonstrates the DataFusion [`Expr`] API. @@ -55,8 +57,7 @@ use datafusion::prelude::*; /// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] /// 6. Get the types of the expressions: [`expression_type_demo`] /// 7. 
Apply type coercion to expressions: [`type_coercion_demo`] -#[tokio::main] -async fn main() -> Result<()> { +pub async fn expr_api() -> Result<()> { // The easiest way to do create expressions is to use the // "fluent"-style API: let expr = col("a") + lit(5); @@ -174,8 +175,10 @@ fn simplify_demo() -> Result<()> { // the ExecutionProps carries information needed to simplify // expressions, such as the current time (to evaluate `now()` // correctly) - let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::builder() + .with_schema(schema) + .with_current_time() + .build(); let simplifier = ExprSimplifier::new(context); // And then call the simplify_expr function: @@ -190,12 +193,15 @@ fn simplify_demo() -> Result<()> { // here are some other examples of what DataFusion is capable of let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema.clone()); + let context = SimplifyContext::builder() + .with_schema(Arc::clone(&schema)) + .with_current_time() + .build(); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification // i + 1 + 2 => i + 3 - // (note this is not done if the expr is (col("i") + (lit(1) + lit(2)))) + // (note this is not done if the expr is (col("i") + lit(1) + lit(2))) assert_eq!( simplifier.simplify(col("i") + (lit(1) + lit(2)))?, col("i") + lit(3) @@ -257,7 +263,7 @@ fn range_analysis_demo() -> Result<()> { // You can provide DataFusion any known boundaries on the values of `date` // (for example, maybe you know you only have data up to `2020-09-15`), but // in this case, let's say we don't know any boundaries beforehand so we use - // `try_new_unknown` + // `try_new_unbounded` let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; // Now, we invoke the analysis code to perform the range analysis @@ -302,6 +308,7 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { min_value: Precision::Exact(ScalarValue::Int64(Some(1))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }; // We can then build our expression boundaries from the column statistics @@ -342,9 +349,11 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { // // (a' - b' + 1) / (a - b) // (10000 - 5000 + 1) / (10000 - 1) - assert!(analysis - .selectivity - .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity))); + assert!( + analysis + .selectivity + .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity)) + ); Ok(()) } @@ -369,6 +378,7 @@ fn boundary_analysis_in_conjunctions_demo() -> Result<()> { min_value: Precision::Exact(ScalarValue::Int64(Some(14))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }; let initial_boundaries = @@ -414,9 +424,11 @@ fn boundary_analysis_in_conjunctions_demo() -> Result<()> { // // Granted a column such as age will more likely follow a Normal distribution // as such our selectivity estimation will not be as good as it can. - assert!(analysis - .selectivity - .is_some_and(|selectivity| (0.1..=0.2).contains(&selectivity))); + assert!( + analysis + .selectivity + .is_some_and(|selectivity| (0.1..=0.2).contains(&selectivity)) + ); // The above example was a good way to look at how we can derive better // interval and get a lower selectivity during boundary analysis. 
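The selectivity arithmetic quoted in these comments, as a tiny worked check:

```rust
fn main() {
    // (max' - min' + 1) / (max - min): uniform-distribution estimate
    let (min, max) = (1.0_f64, 10_000.0_f64);
    let (new_min, new_max) = (5_000.0_f64, 10_000.0_f64);
    let selectivity = (new_max - new_min + 1.0) / (max - min);
    // ~0.50015, inside the (0.5..=0.6) band the example asserts
    assert!((0.5..=0.6).contains(&selectivity));
    println!("estimated selectivity: {selectivity:.5}");
}
```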
@@ -532,10 +544,11 @@ fn type_coercion_demo() -> Result<()> { let physical_expr = datafusion::physical_expr::create_physical_expr(&expr, &df_schema, &props)?; let e = physical_expr.evaluate(&batch).unwrap_err(); - assert!(e - .find_root() - .to_string() - .contains("Invalid comparison operation: Int8 > Int32")); + assert!( + e.find_root() + .to_string() + .contains("Invalid comparison operation: Int8 > Int32") + ); // 1. Type coercion with `SessionContext::create_physical_expr` which implicitly applies type coercion before constructing the physical expr. let physical_expr = @@ -543,7 +556,10 @@ fn type_coercion_demo() -> Result<()> { assert!(physical_expr.evaluate(&batch).is_ok()); // 2. Type coercion with `ExprSimplifier::coerce`. - let context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema.clone())); + let context = SimplifyContext::builder() + .with_schema(Arc::new(df_schema.clone())) + .with_current_time() + .build(); let simplifier = ExprSimplifier::new(context); let coerced_expr = simplifier.coerce(expr.clone(), &df_schema)?; let physical_expr = datafusion::physical_expr::create_physical_expr( diff --git a/datafusion-examples/examples/query_planning/main.rs b/datafusion-examples/examples/query_planning/main.rs new file mode 100644 index 0000000000000..d3f99aedceb3d --- /dev/null +++ b/datafusion-examples/examples/query_planning/main.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Query planning and optimization examples +//! +//! These examples demonstrate internal mechanics of the query planning and optimization layers. +//! +//! ## Usage +//! ```bash +//! cargo run --example query_planning -- [all|analyzer_rule|expr_api|optimizer_rule|parse_sql_expr|plan_to_sql|planner_api|pruning|thread_pools] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `analyzer_rule` +//! (file: analyzer_rule.rs, desc: Custom AnalyzerRule to change query semantics) +//! +//! - `expr_api` +//! (file: expr_api.rs, desc: Create, execute, analyze, and coerce Exprs) +//! +//! - `optimizer_rule` +//! (file: optimizer_rule.rs, desc: Replace predicates via a custom OptimizerRule) +//! +//! - `parse_sql_expr` +//! (file: parse_sql_expr.rs, desc: Parse SQL into DataFusion Expr) +//! +//! - `plan_to_sql` +//! (file: plan_to_sql.rs, desc: Generate SQL from expressions or plans) +//! +//! - `planner_api` +//! (file: planner_api.rs, desc: APIs for logical and physical plan manipulation) +//! +//! - `pruning` +//! (file: pruning.rs, desc: Use pruning to skip irrelevant files) +//! +//! - `thread_pools` +//! 
(file: thread_pools.rs, desc: Configure custom thread pools for DataFusion execution) + +mod analyzer_rule; +mod expr_api; +mod optimizer_rule; +mod parse_sql_expr; +mod plan_to_sql; +mod planner_api; +mod pruning; +mod thread_pools; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + AnalyzerRule, + ExprApi, + OptimizerRule, + ParseSqlExpr, + PlanToSql, + PlannerApi, + Pruning, + ThreadPools, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "query_planning"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AnalyzerRule => analyzer_rule::analyzer_rule().await?, + ExampleKind::ExprApi => expr_api::expr_api().await?, + ExampleKind::OptimizerRule => optimizer_rule::optimizer_rule().await?, + ExampleKind::ParseSqlExpr => parse_sql_expr::parse_sql_expr().await?, + ExampleKind::PlanToSql => plan_to_sql::plan_to_sql_examples().await?, + ExampleKind::PlannerApi => planner_api::planner_api().await?, + ExampleKind::Pruning => pruning::pruning().await?, + ExampleKind::ThreadPools => thread_pools::thread_pools().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/query_planning/optimizer_rule.rs similarity index 97% rename from datafusion-examples/examples/optimizer_rule.rs rename to datafusion-examples/examples/query_planning/optimizer_rule.rs index 9c137b67432c5..67683b7fe2827 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/query_planning/optimizer_rule.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{assert_batches_eq, Result, ScalarValue}; +use datafusion::common::{Result, ScalarValue, assert_batches_eq}; use datafusion::logical_expr::{ BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion::optimizer::ApplyOrder; use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; use datafusion::prelude::SessionContext; -use std::any::Any; use std::sync::Arc; /// This example demonstrates how to add your own [`OptimizerRule`] /// to DataFusion. /// /// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for /// changing plan semantics.
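For the shape of the trait this example implements, a pass-through skeleton (the real rule replaces predicates):

```rust
use datafusion::common::Result;
use datafusion::common::tree_node::Transformed;
use datafusion::logical_expr::LogicalPlan;
use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};

#[derive(Debug)]
struct NoopRule;

impl OptimizerRule for NoopRule {
    fn name(&self) -> &str {
        "noop_rule"
    }

    // visit children before parents
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // `Transformed::no` reports the plan unchanged, so the optimizer
        // knows this rule did no work on this node
        Ok(Transformed::no(plan))
    }
}
```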
-#[tokio::main] -pub async fn main() -> Result<()> { +pub async fn optimizer_rule() -> Result<()> { // DataFusion includes many built in OptimizerRules for tasks such as outer // to inner join conversion and constant folding. // @@ -189,10 +189,6 @@ impl MyEq { } impl ScalarUDFImpl for MyEq { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "my_eq" } diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/query_planning/parse_sql_expr.rs similarity index 68% rename from datafusion-examples/examples/parse_sql_expr.rs rename to datafusion-examples/examples/query_planning/parse_sql_expr.rs index 5387e7c4a05dc..74072b8480f99 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/query_planning/parse_sql_expr.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; +use datafusion::common::ScalarValue; use datafusion::logical_expr::{col, lit}; use datafusion::sql::unparser::Unparser; use datafusion::{ @@ -24,6 +27,7 @@ use datafusion::{ error::Result, prelude::{ParquetReadOptions, SessionContext}, }; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the programmatic parsing of SQL expressions using /// the DataFusion [`SessionContext::parse_sql_expr`] API or the [`DataFrame::parse_sql_expr`] API. @@ -32,17 +36,15 @@ use datafusion::{ /// The code in this example shows how to: /// /// 1. [`simple_session_context_parse_sql_expr_demo`]: Parse a simple SQL text into a logical -/// expression using a schema at [`SessionContext`]. +/// expression using a schema at [`SessionContext`]. /// /// 2. [`simple_dataframe_parse_sql_expr_demo`]: Parse a simple SQL text into a logical expression -/// using a schema at [`DataFrame`]. +/// using a schema at [`DataFrame`]. /// /// 3. [`query_parquet_demo`]: Query a parquet file using the parsed_sql_expr from a DataFrame. /// /// 4. [`round_trip_parse_sql_expr_demo`]: Parse a SQL text and convert it back to SQL using [`Unparser`]. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parse_sql_expr() -> Result<()> { // See how to evaluate expressions simple_session_context_parse_sql_expr_demo()?; simple_dataframe_parse_sql_expr_demo().await?; @@ -70,18 +72,19 @@ fn simple_session_context_parse_sql_expr_demo() -> Result<()> { /// DataFusion can parse a SQL text to an logical expression using schema at [`DataFrame`]. 
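The `SessionContext` flavor of the same API, as a self-contained sketch with an illustrative one-column schema:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

fn main() -> Result<()> {
    // A one-column schema to resolve `a` against (illustrative)
    let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
    let df_schema = DFSchema::try_from(schema)?;

    let ctx = SessionContext::new();
    // SQL text in, logical `Expr` out, resolved against the schema
    let expr = ctx.parse_sql_expr("a < 5", &df_schema)?;
    println!("parsed: {expr}");
    Ok(())
}
```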
async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { - let sql = "int_col < 5 OR double_col = 8.0"; - let expr = col("int_col") - .lt(lit(5_i64)) - .or(col("double_col").eq(lit(8.0_f64))); + let sql = "car = 'red' OR speed > 1.0"; + let expr = col("car") + .eq(lit(ScalarValue::Utf8(Some("red".to_string())))) + .or(col("speed").gt(lit(1.0_f64))); let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let parsed_expr = df.parse_sql_expr(sql)?; @@ -93,39 +96,37 @@ async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { async fn query_parquet_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let df = df .clone() - .select(vec![ - df.parse_sql_expr("int_col")?, - df.parse_sql_expr("double_col")?, - ])? - .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .select(vec![df.parse_sql_expr("car")?, df.parse_sql_expr("speed")?])? + .filter(df.parse_sql_expr("car = 'red' OR speed > 1.0")?)? .aggregate( - vec![df.parse_sql_expr("double_col")?], - vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + vec![df.parse_sql_expr("car")?], + vec![df.parse_sql_expr("SUM(speed) as sum_speed")?], )? // Directly parsing the SQL text into a sort expression is not supported yet, so // construct it programmatically - .sort(vec![col("double_col").sort(false, false)])? + .sort(vec![col("car").sort(false, false)])? .limit(0, Some(1))?; let result = df.collect().await?; assert_batches_eq!( &[ - "+------------+-------------+", - "| double_col | sum_int_col |", - "+------------+-------------+", - "| 10.1 | 4 |", - "+------------+-------------+", + "+-----+--------------------+", + "| car | sum_speed |", + "+-----+--------------------+", + "| red | 162.49999999999997 |", + "+-----+--------------------+" ], &result ); @@ -135,15 +136,16 @@ async fn query_parquet_demo() -> Result<()> { /// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. 
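Unparsing in isolation, with the same `with_pretty` setting the demo uses:

```rust
use datafusion::error::Result;
use datafusion::logical_expr::{col, lit};
use datafusion::sql::unparser::Unparser;

fn main() -> Result<()> {
    let expr = col("car").eq(lit("red")).or(col("speed").gt(lit(1.0_f64)));
    // `with_pretty(true)` drops parentheses that are implied by precedence
    let unparser = Unparser::default().with_pretty(true);
    let sql = unparser.expr_to_sql(&expr)?.to_string();
    println!("{sql}"); // car = 'red' OR speed > 1.0
    Ok(())
}
```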
async fn round_trip_parse_sql_expr_demo() -> Result<()> { - let sql = "((int_col < 5) OR (double_col = 8))"; + let sql = "((car = 'red') OR (speed > 1.0))"; let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let parsed_expr = df.parse_sql_expr(sql)?; @@ -158,7 +160,7 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { // difference in precedence rules between DataFusion and target engines. let unparser = Unparser::default().with_pretty(true); - let pretty = "int_col < 5 OR double_col = 8"; + let pretty = "car = 'red' OR speed > 1.0"; let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); assert_eq!(pretty, pretty_round_trip_sql); diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/query_planning/plan_to_sql.rs similarity index 77% rename from datafusion-examples/examples/plan_to_sql.rs rename to datafusion-examples/examples/query_planning/plan_to_sql.rs index 54483b143a169..86aebbc0b2c33 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/query_planning/plan_to_sql.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +use std::fmt; +use std::sync::Arc; + use datafusion::common::DFSchemaRef; +use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::sqlparser::ast::Statement; use datafusion::logical_expr::{ @@ -32,9 +38,8 @@ use datafusion::sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparse use datafusion::sql::unparser::extension_unparser::{ UnparseToStatementResult, UnparseWithinStatementResult, }; -use datafusion::sql::unparser::{plan_to_sql, Unparser}; -use std::fmt; -use std::sync::Arc; +use datafusion::sql::unparser::{Unparser, plan_to_sql}; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the programmatic construction of SQL strings using /// the DataFusion Expr [`Expr`] and LogicalPlan [`LogicalPlan`] API. @@ -43,28 +48,26 @@ use std::sync::Arc; /// The code in this example shows how to: /// /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with -/// fluent API and convert to sql suitable for passing to another database +/// fluent API and convert to sql suitable for passing to another database /// /// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression -/// [`Exprs`] with fluent API and convert to sql without extra parentheses, -/// suitable for displaying to humans +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans /// /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple -/// expression [`Exprs`] with fluent API and convert to sql escaping column -/// names in MySQL style. +/// expression [`Exprs`] with fluent API and convert to sql escaping column +/// names in MySQL style. /// /// 4. [`simple_plan_to_sql_demo`]: Create a simple logical plan using the -/// DataFrames API and convert to sql string. 
+/// DataFrames API and convert to sql string. /// /// 5. [`round_trip_plan_to_sql_demo`]: Create a logical plan from a SQL string, modify it using the -/// DataFrames API and convert it back to a sql string. +/// DataFrames API and convert it back to a sql string. /// /// 6. [`unparse_my_logical_plan_as_statement`]: Create a custom logical plan and unparse it as a statement. /// /// 7. [`unparse_my_logical_plan_as_subquery`]: Create a custom logical plan and unparse it as a subquery. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn plan_to_sql_examples() -> Result<()> { // See how to evaluate expressions simple_expr_to_sql_demo()?; simple_expr_to_pretty_sql_demo()?; @@ -114,21 +117,21 @@ fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { async fn simple_plan_to_sql_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])?; + .select_columns(&["car", "speed", "time"])?; // Convert the data frame to a SQL string let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + r#"SELECT "?table?".car, "?table?".speed, "?table?"."time" FROM "?table?""# ); Ok(()) @@ -139,35 +142,35 @@ async fn simple_plan_to_sql_demo() -> Result<()> { async fn round_trip_plan_to_sql_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // register parquet file with the execution context ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), + "cars", + parquet_temp.path_str()?, ParquetReadOptions::default(), ) .await?; // create a logical plan from a SQL string and then programmatically add new filters + // select car, speed, time from cars where speed > 1 and car = 'red' let df = ctx // Use SQL to read some data from the parquet file - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain", - ) + .sql("SELECT car, speed, time FROM cars") .await? 
- // Add id > 1 and tinyint_col < double_col filter + // Add speed > 1 and car = 'red' filter .filter( - col("id") + col("speed") .gt(lit(1)) - .and(col("tinyint_col").lt(col("double_col"))), + .and(col("car").eq(lit(ScalarValue::Utf8(Some("red".to_string()))))), )?; let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"# + r#"SELECT cars.car, cars.speed, cars."time" FROM cars WHERE ((cars.speed > 1) AND (cars.car = 'red'))"# ); Ok(()) @@ -211,6 +214,7 @@ impl UserDefinedLogicalNodeCore for MyLogicalPlan { } struct PlanToStatement {} + impl UserDefinedLogicalNodeUnparser for PlanToStatement { fn unparse_to_statement( &self, @@ -231,14 +235,15 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement { /// It can be unparse as a statement that reads from the same parquet file. async fn unparse_my_logical_plan_as_statement() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let inner_plan = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .select_columns(&["car", "speed", "time"])? .into_unoptimized_plan(); let node = Arc::new(MyLogicalPlan { input: inner_plan }); @@ -249,7 +254,7 @@ async fn unparse_my_logical_plan_as_statement() -> Result<()> { let sql = unparser.plan_to_sql(&my_plan)?.to_string(); assert_eq!( sql, - r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + r#"SELECT "?table?".car, "?table?".speed, "?table?"."time" FROM "?table?""# ); Ok(()) } @@ -284,14 +289,15 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { /// It can be unparse as a subquery that reads from the same parquet file, with some columns projected. async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let inner_plan = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .select_columns(&["car", "speed", "time"])? .into_unoptimized_plan(); let node = Arc::new(MyLogicalPlan { input: inner_plan }); @@ -299,8 +305,8 @@ async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let my_plan = LogicalPlan::Extension(Extension { node }); let plan = LogicalPlanBuilder::from(my_plan) .project(vec![ - col("id").alias("my_id"), - col("int_col").alias("my_int"), + col("car").alias("my_car"), + col("speed").alias("my_speed"), ])? 
.build()?; let unparser = @@ -308,8 +314,8 @@ async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let sql = unparser.plan_to_sql(&plan)?.to_string(); assert_eq!( sql, - "SELECT \"?table?\".id AS my_id, \"?table?\".int_col AS my_int FROM \ - (SELECT \"?table?\".id, \"?table?\".int_col, \"?table?\".double_col, \"?table?\".date_string_col FROM \"?table?\")", + "SELECT \"?table?\".car AS my_car, \"?table?\".speed AS my_speed FROM \ + (SELECT \"?table?\".car, \"?table?\".speed, \"?table?\".\"time\" FROM \"?table?\")", ); Ok(()) } diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/query_planning/planner_api.rs similarity index 86% rename from datafusion-examples/examples/planner_api.rs rename to datafusion-examples/examples/query_planning/planner_api.rs index 55aec7b0108a4..8b2c09f4aecba 100644 --- a/datafusion-examples/examples/planner_api.rs +++ b/datafusion-examples/examples/query_planning/planner_api.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::logical_expr::LogicalPlan; use datafusion::physical_plan::displayable; use datafusion::physical_planner::DefaultPhysicalPlanner; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the process of converting logical plan /// into physical execution plans using DataFusion. @@ -32,29 +35,26 @@ use datafusion::prelude::*; /// physical plan: /// - Via the combined `create_physical_plan` API. /// - Utilizing the analyzer, optimizer, and query planner APIs separately. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn planner_api() -> Result<()> { // Set up a DataFusion context and load a Parquet file let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // Construct the input logical plan using DataFrame API let df = df .clone() - .select(vec![ - df.parse_sql_expr("int_col")?, - df.parse_sql_expr("double_col")?, - ])? - .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .select(vec![df.parse_sql_expr("car")?, df.parse_sql_expr("speed")?])? + .filter(df.parse_sql_expr("car = 'red' OR speed > 1.0")?)? .aggregate( - vec![df.parse_sql_expr("double_col")?], - vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + vec![df.parse_sql_expr("car")?], + vec![df.parse_sql_expr("SUM(speed) as sum_speed")?], )? .limit(0, Some(1))?; let logical_plan = df.logical_plan().clone(); diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/query_planning/pruning.rs similarity index 96% rename from datafusion-examples/examples/pruning.rs rename to datafusion-examples/examples/query_planning/pruning.rs index 9a61789662cdd..7fdc4a7952d68 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/query_planning/pruning.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
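+//! +//! The core flow, as a sketch of the code below: +//! ```text +//! expr: x = 5 AND y = 10 +//! predicate = create_pruning_predicate(expr, &schema); +//! predicate.prune(&catalog)? -> Vec<bool> // false => file can be skipped +//! ```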
+ use std::collections::HashSet; use std::sync::Arc; @@ -22,6 +24,7 @@ use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::common::pruning::PruningStatistics; use datafusion::common::{DFSchema, ScalarValue}; +use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_optimizer::pruning::PruningPredicate; @@ -40,8 +43,7 @@ use datafusion::prelude::*; /// one might do as part of a higher level storage engine. See /// `parquet_index.rs` for an example that uses pruning in the context of an /// individual query. -#[tokio::main] -async fn main() { +pub async fn pruning() -> Result<()> { // In this example, we'll use the PruningPredicate to determine if // the expression `x = 5 AND y = 10` can never be true based on statistics @@ -69,7 +71,7 @@ async fn main() { let predicate = create_pruning_predicate(expr, &my_catalog.schema); // Evaluate the predicate for the three files in the catalog - let prune_results = predicate.prune(&my_catalog).unwrap(); + let prune_results = predicate.prune(&my_catalog)?; println!("Pruning results: {prune_results:?}"); // The result is a `Vec` of bool values, one for each file in the catalog @@ -93,6 +95,8 @@ async fn main() { false ] ); + + Ok(()) } /// A simple model catalog that has information about the three files that store @@ -170,7 +174,7 @@ impl PruningStatistics for MyCatalog { None } - fn row_counts(&self, _column: &Column) -> Option<ArrayRef> { + fn row_counts(&self) -> Option<ArrayRef> { // In this example, we know nothing about the number of rows in each file None } @@ -186,6 +190,7 @@ impl PruningStatistics for MyCatalog { } } +#[expect(clippy::needless_pass_by_value)] fn create_pruning_predicate(expr: Expr, schema: &SchemaRef) -> PruningPredicate { let df_schema = DFSchema::try_from(Arc::clone(schema)).unwrap(); let props = ExecutionProps::new(); diff --git a/datafusion-examples/examples/thread_pools.rs b/datafusion-examples/examples/query_planning/thread_pools.rs similarity index 96% rename from datafusion-examples/examples/thread_pools.rs rename to datafusion-examples/examples/query_planning/thread_pools.rs index bba56b2932abc..2ff73a77c4024 100644 --- a/datafusion-examples/examples/thread_pools.rs +++ b/datafusion-examples/examples/query_planning/thread_pools.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example shows how to use separate thread pools (tokio [`Runtime`])s to //! run the IO and CPU intensive parts of DataFusion plans. //! @@ -35,15 +37,17 @@ //! //! [Architecture section]: https://docs.rs/datafusion/latest/datafusion/index.html#thread-scheduling-cpu--io-thread-pools-and-tokio-runtimes +use std::sync::Arc; + use arrow::util::pretty::pretty_format_batches; use datafusion::common::runtime::JoinSet; use datafusion::error::Result; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::stream::StreamExt; use object_store::client::SpawnedReqwestConnector; use object_store::http::HttpBuilder; -use std::sync::Arc; use tokio::runtime::Handle; use tokio::sync::Notify; use url::Url; @@ -64,15 +68,16 @@ use url::Url; /// when using Rust libraries such as `tonic`.
Using a separate `Runtime` for /// CPU bound tasks will often be simpler in larger applications, even though it /// makes this example slightly more complex. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn thread_pools() -> Result<()> { // The first two examples read local files. Enabling the URL table feature // lets us treat filenames as tables in SQL. let ctx = SessionContext::new().enable_url_table(); - let sql = format!( - "SELECT * FROM '{}/alltypes_plain.parquet'", - datafusion::test_util::parquet_test_data() - ); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + + let sql = format!("SELECT * FROM '{}'", parquet_temp.path_str()?); // Run a query on the current runtime. Calling `await` means the future // (in this case the `async` function and all spawned work in DataFusion @@ -121,7 +126,7 @@ async fn same_runtime(ctx: &SessionContext, sql: &str) -> Result<()> { // Executing the plan using this pattern intermixes any IO and CPU intensive // work on same Runtime while let Some(batch) = stream.next().await { - println!("{}", pretty_format_batches(&[batch?]).unwrap()); + println!("{}", pretty_format_batches(&[batch?])?); } Ok(()) } @@ -342,7 +347,7 @@ impl CpuRuntime { /// message such as: /// /// ```text - ///A Tokio 1.x context was found, but IO is disabled. + /// A Tokio 1.x context was found, but IO is disabled. /// ``` pub fn handle(&self) -> &Handle { &self.handle diff --git a/datafusion-examples/examples/relation_planner/main.rs b/datafusion-examples/examples/relation_planner/main.rs new file mode 100644 index 0000000000000..babc0d3714f72 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/main.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Relation Planner Examples +//! +//! These examples demonstrate how to use custom relation planners to extend +//! DataFusion's SQL syntax with custom table operators. +//! +//! ## Usage +//! ```bash +//! cargo run --example relation_planner -- [all|match_recognize|pivot_unpivot|table_sample] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `match_recognize` +//! (file: match_recognize.rs, desc: Implement MATCH_RECOGNIZE pattern matching) +//! +//! - `pivot_unpivot` +//! (file: pivot_unpivot.rs, desc: Implement PIVOT / UNPIVOT) +//! +//! - `table_sample` +//! (file: table_sample.rs, desc: Implement TABLESAMPLE) +//! +//! ## Snapshot Testing +//! +//! These examples use [insta](https://insta.rs) for inline snapshot assertions. +//! If query output changes, regenerate the snapshots with: +//! ```bash +//! 
cargo insta test --example relation_planner --accept +//! ``` + +mod match_recognize; +mod pivot_unpivot; +mod table_sample; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + MatchRecognize, + PivotUnpivot, + TableSample, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "relation_planner"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::MatchRecognize => match_recognize::match_recognize().await?, + ExampleKind::PivotUnpivot => pivot_unpivot::pivot_unpivot().await?, + ExampleKind::TableSample => table_sample::table_sample().await?, + } + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await } + +/// Test wrappers that enable `cargo insta test --example relation_planner --accept` +/// to regenerate inline snapshots. Without these, insta cannot run the examples +/// in test mode since they only have `main()` functions. +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_match_recognize() { + match_recognize::match_recognize().await.unwrap(); + } + + #[tokio::test] + async fn test_pivot_unpivot() { + pivot_unpivot::pivot_unpivot().await.unwrap(); + } + + #[tokio::test] + async fn test_table_sample() { + table_sample::table_sample().await.unwrap(); + } +} diff --git a/datafusion-examples/examples/relation_planner/match_recognize.rs b/datafusion-examples/examples/relation_planner/match_recognize.rs new file mode 100644 index 0000000000000..c4b3d522efc17 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/match_recognize.rs @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # MATCH_RECOGNIZE Example +//! +//! This example demonstrates implementing SQL `MATCH_RECOGNIZE` pattern matching +//! using a custom [`RelationPlanner`]. Unlike the [`pivot_unpivot`] example that +//! rewrites SQL to standard operations, this example creates a **custom logical +//!
plan node** (`MiniMatchRecognizeNode`) to represent the operation. +//! +//! ## Supported Syntax +//! +//! ```sql +//! SELECT * FROM events +//! MATCH_RECOGNIZE ( +//! PARTITION BY region +//! MEASURES SUM(price) AS total, AVG(price) AS average +//! PATTERN (A B+ C) +//! DEFINE +//! A AS price < 100, +//! B AS price BETWEEN 100 AND 200, +//! C AS price > 200 +//! ) AS matches +//! ``` +//! +//! ## Architecture +//! +//! This example demonstrates **logical planning only**. Physical execution would +//! require implementing an [`ExecutionPlan`] (see the [`table_sample`] example +//! for a complete implementation with physical planning). +//! +//! ```text +//! SQL Query +//! │ +//! ▼ +//! ┌─────────────────────────────────────┐ +//! │ MatchRecognizePlanner │ +//! │ (RelationPlanner trait) │ +//! │ │ +//! │ • Parses MATCH_RECOGNIZE syntax │ +//! │ • Creates MiniMatchRecognizeNode │ +//! │ • Converts SQL exprs to DataFusion │ +//! └─────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────┐ +//! │ MiniMatchRecognizeNode │ +//! │ (UserDefinedLogicalNode) │ +//! │ │ +//! │ • measures: [(alias, expr), ...] │ +//! │ • definitions: [(symbol, expr), ...]│ +//! └─────────────────────────────────────┘ +//! ``` +//! +//! [`pivot_unpivot`]: super::pivot_unpivot +//! [`table_sample`]: super::table_sample +//! [`ExecutionPlan`]: datafusion::physical_plan::ExecutionPlan + +use std::{any::Any, cmp::Ordering, hash::Hasher, sync::Arc}; + +use arrow::array::{ArrayRef, Float64Array, Int32Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{DFSchemaRef, Result}; +use datafusion_expr::{ + Expr, UserDefinedLogicalNode, + logical_plan::{Extension, InvariantLevel, LogicalPlan}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::TableFactor; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the MATCH_RECOGNIZE examples demonstrating pattern matching on event streams. +/// +/// Note: This example demonstrates **logical planning only**. Physical execution +/// would require additional implementation of an [`ExecutionPlan`]. 
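+/// +/// To run just this example (via the dispatcher in `main.rs`): +/// ```bash +/// cargo run --example relation_planner -- match_recognize +/// ```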
+pub async fn match_recognize() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(MatchRecognizePlanner))?; + register_sample_data(&ctx)?; + + println!("MATCH_RECOGNIZE Example (Logical Planning Only)"); + println!("================================================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // Example 1: Basic MATCH_RECOGNIZE with MEASURES and DEFINE + // Demonstrates: Aggregate measures over matched rows + let plan = run_example( + ctx, + "Example 1: MATCH_RECOGNIZE with aggregations", + r#"SELECT * FROM events + MATCH_RECOGNIZE ( + PARTITION BY 1 + MEASURES SUM(price) AS total_price, AVG(price) AS avg_price + PATTERN (A) + DEFINE A AS price > 10 + ) AS matches"#, + ) + .await?; + assert_snapshot!(plan, @r" + Projection: matches.price + SubqueryAlias: matches + MiniMatchRecognize measures=[total_price := sum(events.price), avg_price := avg(events.price)] define=[a := events.price > Int64(10)] + TableScan: events + "); + + // Example 2: Stock price pattern detection + // Demonstrates: Real-world use case finding prices above threshold + let plan = run_example( + ctx, + "Example 2: Detect high stock prices", + r#"SELECT * FROM stock_prices + MATCH_RECOGNIZE ( + MEASURES + MIN(price) AS min_price, + MAX(price) AS max_price, + AVG(price) AS avg_price + PATTERN (HIGH) + DEFINE HIGH AS price > 151.0 + ) AS trends"#, + ) + .await?; + assert_snapshot!(plan, @r" + Projection: trends.symbol, trends.price + SubqueryAlias: trends + MiniMatchRecognize measures=[min_price := min(stock_prices.price), max_price := max(stock_prices.price), avg_price := avg(stock_prices.price)] define=[high := stock_prices.price > Float64(151)] + TableScan: stock_prices + "); + + Ok(()) +} + +/// Helper to run a single example query and display the logical plan. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<String> { + println!("{title}:\n{sql}\n"); + let plan = ctx.sql(sql).await?.into_unoptimized_plan(); + let plan_str = plan.display_indent().to_string(); + println!("{plan_str}\n"); + Ok(plan_str) +} + +/// Register test data tables. +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // events: simple price series + ctx.register_batch( + "events", + RecordBatch::try_from_iter(vec![( + "price", + Arc::new(Int32Array::from(vec![5, 12, 8, 15, 20])) as ArrayRef, + )])?, + )?; + + // stock_prices: realistic stock data + ctx.register_batch( + "stock_prices", + RecordBatch::try_from_iter(vec![ + ( + "symbol", + Arc::new(StringArray::from(vec!["DDOG", "DDOG", "DDOG", "DDOG"])) + as ArrayRef, + ), + ( + "price", + Arc::new(Float64Array::from(vec![150.0, 155.0, 152.0, 158.0])), + ), + ])?, + )?; + + Ok(()) +} + +// ============================================================================ +// Logical Plan Node: MiniMatchRecognizeNode +// ============================================================================ + +/// A custom logical plan node representing MATCH_RECOGNIZE operations.
+/// +/// This is a simplified implementation that captures the essential structure: +/// - `measures`: Aggregate expressions computed over matched rows +/// - `definitions`: Symbol definitions (predicate expressions) +/// +/// A production implementation would also include: +/// - Pattern specification (regex-like pattern) +/// - Partition and order by clauses +/// - Output mode (ONE ROW PER MATCH, ALL ROWS PER MATCH) +/// - After match skip strategy +#[derive(Debug)] +struct MiniMatchRecognizeNode { + input: Arc<LogicalPlan>, + schema: DFSchemaRef, + /// Measures: (alias, aggregate_expr) + measures: Vec<(String, Expr)>, + /// Symbol definitions: (symbol_name, predicate_expr) + definitions: Vec<(String, Expr)>, +} + +impl UserDefinedLogicalNode for MiniMatchRecognizeNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "MiniMatchRecognize" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { + Ok(()) + } + + fn expressions(&self) -> Vec<Expr> { + self.measures + .iter() + .chain(&self.definitions) + .map(|(_, expr)| expr.clone()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MiniMatchRecognize")?; + + if !self.measures.is_empty() { + write!(f, " measures=[")?; + for (i, (alias, expr)) in self.measures.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{alias} := {expr}")?; + } + write!(f, "]")?; + } + + if !self.definitions.is_empty() { + write!(f, " define=[")?; + for (i, (symbol, expr)) in self.definitions.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{symbol} := {expr}")?; + } + write!(f, "]")?; + } + + Ok(()) + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec<Expr>, + inputs: Vec<LogicalPlan>, + ) -> Result<Arc<dyn UserDefinedLogicalNode>> { + let expected_len = self.measures.len() + self.definitions.len(); + if exprs.len() != expected_len { + return Err(datafusion_common::plan_datafusion_err!( + "MiniMatchRecognize: expected {expected_len} expressions, got {}", + exprs.len() + )); + } + + let input = inputs.into_iter().next().ok_or_else(|| { + datafusion_common::plan_datafusion_err!( + "MiniMatchRecognize requires exactly one input" + ) + })?; + + let (measure_exprs, definition_exprs) = exprs.split_at(self.measures.len()); + + let measures = self + .measures + .iter() + .zip(measure_exprs) + .map(|((alias, _), expr)| (alias.clone(), expr.clone())) + .collect(); + + let definitions = self + .definitions + .iter() + .zip(definition_exprs) + .map(|((symbol, _), expr)| (symbol.clone(), expr.clone())) + .collect(); + + Ok(Arc::new(Self { + input: Arc::new(input), + schema: Arc::clone(&self.schema), + measures, + definitions, + })) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + state.write_usize(Arc::as_ptr(&self.input) as usize); + state.write_usize(self.measures.len()); + state.write_usize(self.definitions.len()); + } + + fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { + other.as_any().downcast_ref::<Self>().is_some_and(|o| { + Arc::ptr_eq(&self.input, &o.input) + && self.measures == o.measures + && self.definitions == o.definitions + }) + } + + fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> { + if self.dyn_eq(other) { + Some(Ordering::Equal) + } else { + None + } + } +} + +// ============================================================================ +// Relation Planner: MatchRecognizePlanner +//
============================================================================ + +/// Relation planner that creates `MiniMatchRecognizeNode` for MATCH_RECOGNIZE queries. +#[derive(Debug)] +struct MatchRecognizePlanner; + +impl RelationPlanner for MatchRecognizePlanner { + fn plan_relation( + &self, + relation: TableFactor, + ctx: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + let TableFactor::MatchRecognize { + table, + measures, + symbols, + alias, + .. + } = relation + else { + return Ok(RelationPlanning::Original(Box::new(relation))); + }; + + // Plan the input table + let input = ctx.plan(*table)?; + let schema = input.schema().clone(); + + // Convert MEASURES: SQL expressions → DataFusion expressions + let planned_measures: Vec<(String, Expr)> = measures + .iter() + .map(|m| { + let alias = ctx.normalize_ident(m.alias.clone()); + let expr = ctx.sql_to_expr(m.expr.clone(), schema.as_ref())?; + Ok((alias, expr)) + }) + .collect::<Result<_>>()?; + + // Convert DEFINE: symbol definitions → DataFusion expressions + let planned_definitions: Vec<(String, Expr)> = symbols + .iter() + .map(|s| { + let name = ctx.normalize_ident(s.symbol.clone()); + let expr = ctx.sql_to_expr(s.definition.clone(), schema.as_ref())?; + Ok((name, expr)) + }) + .collect::<Result<_>>()?; + + // Create the custom node + let node = MiniMatchRecognizeNode { + input: Arc::new(input), + schema, + measures: planned_measures, + definitions: planned_definitions, + }; + + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(node), + }); + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } +} diff --git a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs new file mode 100644 index 0000000000000..4b721346aa72d --- /dev/null +++ b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs @@ -0,0 +1,619 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # PIVOT and UNPIVOT Example +//! +//! This example demonstrates implementing SQL `PIVOT` and `UNPIVOT` operations +//! using a custom [`RelationPlanner`]. Unlike the other examples that create +//! custom logical/physical nodes, this example shows how to **rewrite** SQL +//! constructs into equivalent standard SQL operations: +//! +//! ## Supported Syntax +//! +//! ```sql +//! -- PIVOT: Transform rows into columns +//! SELECT * FROM sales +//! PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4')) +//! +//! -- UNPIVOT: Transform columns into rows +//! SELECT * FROM wide_table +//! UNPIVOT (value FOR name IN (col1, col2, col3)) +//! ``` +//! +//! ## Rewrite Strategy +//! +//! **PIVOT** is rewritten to `GROUP BY` with `CASE` expressions: +//! ```sql +//! -- Original: +//!
SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) +//! +//! -- Rewritten to: +//! SELECT region, +//! SUM(CASE quarter WHEN 'Q1' THEN amount END) AS Q1, +//! SUM(CASE quarter WHEN 'Q2' THEN amount END) AS Q2 +//! FROM sales +//! GROUP BY region +//! ``` +//! +//! **UNPIVOT** is rewritten to `UNION ALL` of projections: +//! ```sql +//! -- Original: +//! SELECT * FROM wide UNPIVOT (sales FOR quarter IN (q1, q2)) +//! +//! -- Rewritten to: +//! SELECT region, 'q1' AS quarter, q1 AS sales FROM wide +//! UNION ALL +//! SELECT region, 'q2' AS quarter, q2 AS sales FROM wide +//! ``` + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue, plan_datafusion_err}; +use datafusion_expr::{ + Expr, case, col, lit, + logical_plan::builder::LogicalPlanBuilder, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::{NullInclusion, PivotValueSource, TableFactor}; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the PIVOT/UNPIVOT examples demonstrating data reshaping operations. +pub async fn pivot_unpivot() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(PivotUnpivotPlanner))?; + register_sample_data(&ctx)?; + + println!("PIVOT and UNPIVOT Example"); + println!("=========================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // ----- PIVOT Examples ----- + + // Example 1: Basic PIVOT + // Transforms: (region, quarter, amount) → (region, Q1, Q2) + let results = run_example( + ctx, + "Example 1: Basic PIVOT", + r#"SELECT * FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+------+------+ + | region | Q1 | Q2 | + +--------+------+------+ + | North | 1000 | 1500 | + | South | 1200 | 1300 | + +--------+------+------+ + "); + + // Example 2: PIVOT with multiple aggregates + // Creates columns for each (aggregate, value) combination + let results = run_example( + ctx, + "Example 2: PIVOT with multiple aggregates", + r#"SELECT * FROM quarterly_sales + PIVOT (SUM(amount), AVG(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+--------+--------+--------+--------+ + | region | sum_Q1 | sum_Q2 | avg_Q1 | avg_Q2 | + +--------+--------+--------+--------+--------+ + | North | 1000 | 1500 | 1000.0 | 1500.0 | + | South | 1200 | 1300 | 1200.0 | 1300.0 | + +--------+--------+--------+--------+--------+ + "); + + // Example 3: PIVOT with multiple grouping columns + // Non-pivot, non-aggregate columns become GROUP BY columns + let results = run_example( + ctx, + "Example 3: PIVOT with multiple grouping columns", + r#"SELECT * FROM product_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region, product"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+----------+-----+-----+ + | region | product | Q1 | Q2 | + +--------+----------+-----+-----+ + | North | ProductA | 500 | | + | North | ProductB | 500 | | + | South | ProductA | | 650 | + +--------+----------+-----+-----+ + "); + + // ----- UNPIVOT Examples 
----- + + // Example 4: Basic UNPIVOT + // Transforms: (region, q1, q2) → (region, quarter, sales) + let results = run_example( + ctx, + "Example 4: Basic UNPIVOT", + r#"SELECT * FROM wide_sales + UNPIVOT (sales FOR quarter IN (q1 AS 'Q1', q2 AS 'Q2')) AS u + ORDER BY quarter, region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+---------+-------+ + | region | quarter | sales | + +--------+---------+-------+ + | North | Q1 | 1000 | + | South | Q1 | 1200 | + | North | Q2 | 1500 | + | South | Q2 | 1300 | + +--------+---------+-------+ + "); + + // Example 5: UNPIVOT with INCLUDE NULLS + // By default, UNPIVOT excludes rows where the value column is NULL. + // INCLUDE NULLS keeps them (same result here since no NULLs in data). + let results = run_example( + ctx, + "Example 5: UNPIVOT INCLUDE NULLS", + r#"SELECT * FROM wide_sales + UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS 'Q1', q2 AS 'Q2')) AS u + ORDER BY quarter, region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+---------+-------+ + | region | quarter | sales | + +--------+---------+-------+ + | North | Q1 | 1000 | + | South | Q1 | 1200 | + | North | Q2 | 1500 | + | South | Q2 | 1300 | + +--------+---------+-------+ + "); + + // Example 6: PIVOT with column projection + // Standard SQL operations work seamlessly after PIVOT + let results = run_example( + ctx, + "Example 6: PIVOT with projection", + r#"SELECT region FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+ + | region | + +--------+ + | North | + | South | + +--------+ + "); + + // Example 7: PIVOT on a quoted mixed-case column + // Reuses the parsed column expression so quoted identifiers keep their case. + let results = run_example( + ctx, + "Example 7: PIVOT with quoted mixed-case column", + r#"SELECT * FROM point_stats + PIVOT (MAX(max_value) FOR "pointNumber" IN ('16951' AS p16951, '16952' AS p16952)) AS p + ORDER BY ts"#, + ) + .await?; + assert_snapshot!(results, @r" + +----------------------+------+--------+--------+ + | ts | port | p16951 | p16952 | + +----------------------+------+--------+--------+ + | 2024-09-01T10:00:00Z | 2411 | 10 | 20 | + | 2024-09-01T10:01:00Z | 2411 | 30 | 40 | + +----------------------+------+--------+--------+ + "); + + Ok(()) +} + +/// Helper to run a single example query and capture results. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<String> { + println!("{title}:\n{sql}\n"); + let df = ctx.sql(sql).await?; + println!("{}\n", df.logical_plan().display_indent()); + + let batches = df.collect().await?; + let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string(); + println!("{results}\n"); + + Ok(results) +} + +/// Register test data tables.
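+/// +/// Input shapes used above (a sketch; the actual batches are below): +/// ```text +/// quarterly_sales(region, quarter, amount) -- normalized input for PIVOT +/// wide_sales(region, q1, q2) -- wide input for UNPIVOT +/// ```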
+fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // quarterly_sales: normalized sales data (region, quarter, amount) + ctx.register_batch( + "quarterly_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "North", "South", "South"])) + as ArrayRef, + ), + ( + "quarter", + Arc::new(StringArray::from(vec!["Q1", "Q2", "Q1", "Q2"])), + ), + ( + "amount", + Arc::new(Int64Array::from(vec![1000, 1500, 1200, 1300])), + ), + ])?, + )?; + + // product_sales: sales with additional grouping dimension + ctx.register_batch( + "product_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "North", "South"])) as ArrayRef, + ), + ( + "quarter", + Arc::new(StringArray::from(vec!["Q1", "Q1", "Q2"])), + ), + ( + "product", + Arc::new(StringArray::from(vec!["ProductA", "ProductB", "ProductA"])), + ), + ("amount", Arc::new(Int64Array::from(vec![500, 500, 650]))), + ])?, + )?; + + // wide_sales: denormalized/wide format (for UNPIVOT) + ctx.register_batch( + "wide_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "South"])) as ArrayRef, + ), + ("q1", Arc::new(Int64Array::from(vec![1000, 1200]))), + ("q2", Arc::new(Int64Array::from(vec![1500, 1300]))), + ])?, + )?; + + // point_stats: grouped data with a quoted mixed-case pivot column. + ctx.register_batch( + "point_stats", + RecordBatch::try_from_iter(vec![ + ( + "ts", + Arc::new(StringArray::from(vec![ + "2024-09-01T10:00:00Z", + "2024-09-01T10:00:00Z", + "2024-09-01T10:01:00Z", + "2024-09-01T10:01:00Z", + ])) as ArrayRef, + ), + ( + "pointNumber", + Arc::new(StringArray::from(vec!["16951", "16952", "16951", "16952"])), + ), + ( + "port", + Arc::new(StringArray::from(vec!["2411", "2411", "2411", "2411"])), + ), + ( + "max_value", + Arc::new(Int64Array::from(vec![10, 20, 30, 40])), + ), + ])?, + )?; + + Ok(()) +} + +// ============================================================================ +// Relation Planner: PivotUnpivotPlanner +// ============================================================================ + +/// Relation planner that rewrites PIVOT and UNPIVOT into standard SQL. +#[derive(Debug)] +struct PivotUnpivotPlanner; + +impl RelationPlanner for PivotUnpivotPlanner { + fn plan_relation( + &self, + relation: TableFactor, + ctx: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::Pivot { + table, + aggregate_functions, + value_column, + value_source, + alias, + .. + } => plan_pivot( + ctx, + *table, + &aggregate_functions, + &value_column, + value_source, + alias, + ), + + TableFactor::Unpivot { + table, + value, + name, + columns, + null_inclusion, + alias, + } => plan_unpivot( + ctx, + *table, + &value, + name, + &columns, + null_inclusion.as_ref(), + alias, + ), + + other => Ok(RelationPlanning::Original(Box::new(other))), + } + } +} + +// ============================================================================ +// PIVOT Implementation +// ============================================================================ + +/// Rewrite PIVOT to GROUP BY with CASE expressions. 
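+/// +/// For each (aggregate, pivot value) pair this builds, roughly: +/// ```text +/// SUM(CASE quarter WHEN 'Q1' THEN amount END) AS Q1 +/// ``` +/// and the remaining non-pivot, non-aggregate columns become the GROUP BY keys.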
+fn plan_pivot( + ctx: &mut dyn RelationPlannerContext, + table: TableFactor, + aggregate_functions: &[datafusion_sql::sqlparser::ast::ExprWithAlias], + value_column: &[datafusion_sql::sqlparser::ast::Expr], + value_source: PivotValueSource, + alias: Option<datafusion_sql::sqlparser::ast::TableAlias>, +) -> Result<RelationPlanning> { + // Plan the input table + let input = ctx.plan(table)?; + let schema = input.schema(); + + // Parse aggregate functions + let aggregates: Vec<Expr> = aggregate_functions + .iter() + .map(|agg| ctx.sql_to_expr(agg.expr.clone(), schema.as_ref())) + .collect::<Result<_>>()?; + + // Get the pivot column (only single-column pivot supported) + if value_column.len() != 1 { + return Err(plan_datafusion_err!( + "Only single-column PIVOT is supported" + )); + } + let pivot_col = ctx.sql_to_expr(value_column[0].clone(), schema.as_ref())?; + let pivot_col_name = extract_column_name(&pivot_col)?; + + // Parse pivot values + let pivot_values = match value_source { + PivotValueSource::List(list) => list + .iter() + .map(|item| { + let alias = item + .alias + .as_ref() + .map(|id| ctx.normalize_ident(id.clone())); + let expr = ctx.sql_to_expr(item.expr.clone(), schema.as_ref())?; + Ok((alias, expr)) + }) + .collect::<Result<Vec<_>>>()?, + _ => { + return Err(plan_datafusion_err!( + "Dynamic PIVOT (ANY/Subquery) is not supported" + )); + } + }; + + // Determine GROUP BY columns (non-pivot, non-aggregate columns) + let agg_input_cols: Vec<&str> = aggregates + .iter() + .filter_map(|agg| { + if let Expr::AggregateFunction(f) = agg { + f.params.args.first().and_then(|e| { + if let Expr::Column(c) = e { + Some(c.name.as_str()) + } else { + None + } + }) + } else { + None + } + }) + .collect(); + + let group_by_cols: Vec<Expr> = schema + .iter() + .filter(|(_, field)| { + let name = field.name(); + name != pivot_col_name.as_str() && !agg_input_cols.contains(&name.as_str()) + }) + .map(Expr::from) + .collect(); + + // Build CASE expressions for each (aggregate, pivot_value) pair + let mut pivot_exprs = Vec::new(); + for agg in &aggregates { + let Expr::AggregateFunction(agg_fn) = agg else { + continue; + }; + let Some(agg_input) = agg_fn.params.args.first().cloned() else { + continue; + }; + + for (value_alias, pivot_value) in &pivot_values { + // CASE pivot_col WHEN pivot_value THEN agg_input END + let case_expr = case(pivot_col.clone()) + .when(pivot_value.clone(), agg_input.clone()) + .end()?; + + // Wrap in aggregate function + let pivoted = agg_fn.func.call(vec![case_expr]); + + // Determine column alias + let value_str = value_alias + .clone() + .unwrap_or_else(|| expr_to_string(pivot_value)); + let col_alias = if aggregates.len() > 1 { + format!("{}_{}", agg_fn.func.name(), value_str) + } else { + value_str + }; + + pivot_exprs.push(pivoted.alias(col_alias)); + } + } + + let plan = LogicalPlanBuilder::from(input) + .aggregate(group_by_cols, pivot_exprs)? + .build()?; + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) +} + +// ============================================================================ +// UNPIVOT Implementation +// ============================================================================ + +/// Rewrite UNPIVOT to UNION ALL of projections.
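+/// +/// Each unpivoted column becomes one branch, roughly: +/// ```text +/// SELECT <kept columns>, '<label>' AS <name>, <column> AS <value> FROM input +/// ``` +/// combined with UNION ALL and, unless INCLUDE NULLS was given, filtered +/// with `IS NOT NULL` on the value column.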
+fn plan_unpivot( + ctx: &mut dyn RelationPlannerContext, + table: TableFactor, + value: &datafusion_sql::sqlparser::ast::Expr, + name: datafusion_sql::sqlparser::ast::Ident, + columns: &[datafusion_sql::sqlparser::ast::ExprWithAlias], + null_inclusion: Option<&NullInclusion>, + alias: Option<datafusion_sql::sqlparser::ast::TableAlias>, +) -> Result<RelationPlanning> { + // Plan the input table + let input = ctx.plan(table)?; + let schema = input.schema(); + + // Output column names + let value_col_name = value.to_string(); + let name_col_name = ctx.normalize_ident(name); + + // Parse columns to unpivot: (source_column, label) + let unpivot_cols: Vec<(String, String)> = columns + .iter() + .map(|c| { + let label = c + .alias + .as_ref() + .map(|id| ctx.normalize_ident(id.clone())) + .unwrap_or_else(|| c.expr.to_string()); + let expr = ctx.sql_to_expr(c.expr.clone(), schema.as_ref())?; + let col_name = extract_column_name(&expr)?; + Ok((col_name.to_string(), label)) + }) + .collect::<Result<_>>()?; + + // Columns to preserve (not being unpivoted) + let keep_cols: Vec<&str> = schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .filter(|name| !unpivot_cols.iter().any(|(c, _)| c == *name)) + .collect(); + + // Build UNION ALL: one SELECT per unpivot column + if unpivot_cols.is_empty() { + return Err(plan_datafusion_err!("UNPIVOT requires at least one column")); + } + + let mut union_inputs: Vec<_> = unpivot_cols + .iter() + .map(|(col_name, label)| { + let mut projection: Vec<Expr> = keep_cols.iter().map(|c| col(*c)).collect(); + projection.push(lit(label.clone()).alias(&name_col_name)); + projection.push(col(col_name).alias(&value_col_name)); + + LogicalPlanBuilder::from(input.clone()) + .project(projection)? + .build() + }) + .collect::<Result<_>>()?; + + // Combine with UNION ALL + let mut plan = union_inputs.remove(0); + for branch in union_inputs { + plan = LogicalPlanBuilder::from(plan).union(branch)?.build()?; + } + + // Apply EXCLUDE NULLS filter (default behavior) + let exclude_nulls = null_inclusion.is_none() + || matches!(null_inclusion, Some(&NullInclusion::ExcludeNulls)); + if exclude_nulls { + plan = LogicalPlanBuilder::from(plan) + .filter(col(&value_col_name).is_not_null())? + .build()?; + } + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) +} + +// ============================================================================ +// Helpers +// ============================================================================ + +/// Extract column name from an expression. +fn extract_column_name(expr: &Expr) -> Result<String> { + match expr { + Expr::Column(c) => Ok(c.name.clone()), + _ => Err(plan_datafusion_err!( + "Expected column reference, got {expr}" + )), + } +} + +/// Convert an expression to a string for use as column alias. +fn expr_to_string(expr: &Expr) -> String { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(), + Expr::Literal(v, _) => v.to_string(), + other => other.to_string(), + } +} diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs new file mode 100644 index 0000000000000..42342e5f1a641 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -0,0 +1,831 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # TABLESAMPLE Example +//! +//! This example demonstrates implementing SQL `TABLESAMPLE` support using +//! DataFusion's extensibility APIs. +//! +//! This is a working `TABLESAMPLE` implementation that can serve as a starting +//! point for your own projects. It also works as a template for adding other +//! custom SQL operators, covering the full pipeline from parsing to execution. +//! +//! It shows how to: +//! +//! 1. **Parse** TABLESAMPLE syntax via a custom [`RelationPlanner`] +//! 2. **Plan** sampling as a custom logical node ([`TableSamplePlanNode`]) +//! 3. **Execute** sampling via a custom physical operator ([`SampleExec`]) +//! +//! ## Supported Syntax +//! +//! ```sql +//! -- Bernoulli sampling (each row has N% chance of selection) +//! SELECT * FROM table TABLESAMPLE BERNOULLI(10 PERCENT) +//! +//! -- Fractional sampling (0.0 to 1.0) +//! SELECT * FROM table TABLESAMPLE (0.1) +//! +//! -- Row count limit +//! SELECT * FROM table TABLESAMPLE (100 ROWS) +//! +//! -- Reproducible sampling with a seed +//! SELECT * FROM table TABLESAMPLE (10 PERCENT) REPEATABLE(42) +//! ``` +//! +//! ## Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ SQL Query │ +//! │ SELECT * FROM t TABLESAMPLE BERNOULLI(10 PERCENT) REPEATABLE(1)│ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSamplePlanner │ +//! │ (RelationPlanner: parses TABLESAMPLE, creates logical node) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSamplePlanNode │ +//! │ (UserDefinedLogicalNode: stores sampling params) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSampleExtensionPlanner │ +//! │ (ExtensionPlanner: creates physical execution plan) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ SampleExec │ +//! │ (ExecutionPlan: performs actual row sampling at runtime) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! 
``` + +use std::{ + fmt::{self, Debug, Formatter}, + hash::{Hash, Hasher}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::datatypes::{Float64Type, Int64Type}; +use arrow::{ + array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array}, + compute, +}; +use arrow_schema::SchemaRef; +use futures::{ + ready, + stream::{Stream, StreamExt}, +}; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use tonic::async_trait; + +use datafusion::optimizer::simplify_expressions::simplify_literal::parse_literal; +use datafusion::{ + execution::{ + RecordBatchStream, SendableRecordBatchStream, SessionState, SessionStateBuilder, + TaskContext, context::QueryPlanner, + }, + physical_expr::EquivalenceProperties, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput}, + }, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, + prelude::*, +}; +use datafusion_common::{ + DFSchemaRef, DataFusionError, Result, Statistics, internal_err, not_impl_err, + plan_datafusion_err, plan_err, tree_node::TreeNodeRecursion, +}; +use datafusion_expr::{ + UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + logical_plan::{Extension, LogicalPlan, LogicalPlanBuilder}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::{ + self, TableFactor, TableSampleMethod, TableSampleUnit, +}; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the TABLESAMPLE examples demonstrating various sampling techniques. 
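+/// +/// Run it in isolation with: +/// ```bash +/// cargo run --example relation_planner -- table_sample +/// ``` +/// The percentage and fraction examples pass `REPEATABLE(seed)` so that the +/// sampled output stays deterministic for the snapshot assertions.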
+pub async fn table_sample() -> Result<()> { + // Build session with custom query planner for physical planning + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(TableSampleQueryPlanner)) + .build(); + + let ctx = SessionContext::new_with_state(state); + + // Register custom relation planner for logical planning + ctx.register_relation_planner(Arc::new(TableSamplePlanner))?; + register_sample_data(&ctx)?; + + println!("TABLESAMPLE Example"); + println!("===================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // Example 1: Baseline - full table scan + let results = run_example( + ctx, + "Example 1: Full table (baseline)", + "SELECT * FROM sample_data", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 3 | row_3 | + | 4 | row_4 | + | 5 | row_5 | + | 6 | row_6 | + | 7 | row_7 | + | 8 | row_8 | + | 9 | row_9 | + | 10 | row_10 | + +---------+---------+ + "); + + // Example 2: Percentage-based Bernoulli sampling + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 2: BERNOULLI percentage sampling", + "SELECT * FROM sample_data TABLESAMPLE BERNOULLI(30 PERCENT) REPEATABLE(123)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 7 | row_7 | + | 8 | row_8 | + +---------+---------+ + "); + + // Example 3: Fractional sampling (0.0 to 1.0) + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 3: Fractional sampling", + "SELECT * FROM sample_data TABLESAMPLE (0.5) REPEATABLE(456)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 2 | row_2 | + | 4 | row_4 | + | 8 | row_8 | + +---------+---------+ + "); + + // Example 4: Row count limit (deterministic, no seed needed) + let results = run_example( + ctx, + "Example 4: Row count limit", + "SELECT * FROM sample_data TABLESAMPLE (3 ROWS)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 3 | row_3 | + +---------+---------+ + "); + + // Example 5: Sampling combined with filtering + let results = run_example( + ctx, + "Example 5: Sampling with WHERE clause", + "SELECT * FROM sample_data TABLESAMPLE (5 ROWS) WHERE column1 > 2", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 3 | row_3 | + | 4 | row_4 | + | 5 | row_5 | + +---------+---------+ + "); + + // Example 6: Sampling in JOIN queries + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 6: Sampling in JOINs", + r#"SELECT t1.column1, t2.column1, t1.column2, t2.column2 + FROM sample_data t1 TABLESAMPLE (0.7) REPEATABLE(789) + JOIN sample_data t2 TABLESAMPLE (0.7) REPEATABLE(123) + ON t1.column1 = t2.column1"#, + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+---------+---------+ + | column1 | column1 | column2 | column2 | + +---------+---------+---------+---------+ + | 2 | 2 | row_2 | row_2 | + | 5 | 5 | row_5 | row_5 | + | 7 | 7 | row_7 | row_7 | + | 8 | 8 | row_8 | row_8 | + | 10 | 10 | row_10 | row_10 | + 
+---------+---------+---------+---------+
+    ");
+
+    Ok(())
+}
+
+/// Helper to run a single example query and capture results.
+async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<String> {
+    println!("{title}:\n{sql}\n");
+    let df = ctx.sql(sql).await?;
+    println!("{}\n", df.logical_plan().display_indent());
+
+    let batches = df.collect().await?;
+    let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string();
+    println!("{results}\n");
+
+    Ok(results)
+}
+
+/// Register test data: 10 rows with column1=1..10 and column2="row_1".."row_10"
+fn register_sample_data(ctx: &SessionContext) -> Result<()> {
+    let column1: ArrayRef = Arc::new(Int32Array::from((1..=10).collect::<Vec<i32>>()));
+    let column2: ArrayRef = Arc::new(StringArray::from(
+        (1..=10).map(|i| format!("row_{i}")).collect::<Vec<String>>(),
+    ));
+    let batch =
+        RecordBatch::try_from_iter(vec![("column1", column1), ("column2", column2)])?;
+    ctx.register_batch("sample_data", batch)?;
+    Ok(())
+}
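+
+// NOTE (editorial sketch): the planner in the next section follows the
+// general `RelationPlanner` contract: return `RelationPlanning::Original` to
+// decline a relation (letting other planners or the default planner handle
+// it), or `RelationPlanning::Planned` with a finished `LogicalPlan`. A
+// minimal pass-through planner (hypothetical, for illustration only):
+//
+//     #[derive(Debug)]
+//     struct PassThroughPlanner;
+//
+//     impl RelationPlanner for PassThroughPlanner {
+//         fn plan_relation(
+//             &self,
+//             relation: TableFactor,
+//             _context: &mut dyn RelationPlannerContext,
+//         ) -> Result<RelationPlanning> {
+//             // Decline everything; planning proceeds unchanged.
+//             Ok(RelationPlanning::Original(Box::new(relation)))
+//         }
+//     }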
+
+// ============================================================================
+// Logical Planning: TableSamplePlanner + TableSamplePlanNode
+// ============================================================================
+
+/// Relation planner that intercepts `TABLESAMPLE` clauses in SQL and creates
+/// [`TableSamplePlanNode`] logical nodes.
+#[derive(Debug)]
+struct TableSamplePlanner;
+
+impl RelationPlanner for TableSamplePlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        // Only handle Table relations with TABLESAMPLE clause
+        let TableFactor::Table {
+            sample: Some(sample),
+            alias,
+            name,
+            args,
+            with_hints,
+            version,
+            with_ordinality,
+            partitions,
+            json_path,
+            index_hints,
+        } = relation
+        else {
+            return Ok(RelationPlanning::Original(Box::new(relation)));
+        };
+
+        // Extract sample spec (handles both before/after alias positions)
+        let sample = match sample {
+            ast::TableSampleKind::BeforeTableAlias(s)
+            | ast::TableSampleKind::AfterTableAlias(s) => s,
+        };
+
+        // Validate sampling method
+        if let Some(method) = &sample.name
+            && *method != TableSampleMethod::Bernoulli
+            && *method != TableSampleMethod::Row
+        {
+            return not_impl_err!(
+                "Sampling method {} is not supported (only BERNOULLI and ROW)",
+                method
+            );
+        }
+
+        // Offset sampling (ClickHouse-style) not supported
+        if sample.offset.is_some() {
+            return not_impl_err!(
+                "TABLESAMPLE with OFFSET is not supported (requires total row count)"
+            );
+        }
+
+        // Parse optional REPEATABLE seed
+        let seed = sample
+            .seed
+            .map(|s| {
+                s.value.to_string().parse::<u64>().map_err(|_| {
+                    plan_datafusion_err!("REPEATABLE seed must be an integer")
+                })
+            })
+            .transpose()?;
+
+        // Plan the underlying table without the sample clause
+        let base_relation = TableFactor::Table {
+            sample: None,
+            alias: alias.clone(),
+            name,
+            args,
+            with_hints,
+            version,
+            with_ordinality,
+            partitions,
+            json_path,
+            index_hints,
+        };
+        let input = context.plan(base_relation)?;
+
+        // Handle bucket sampling (Hive-style: TABLESAMPLE(BUCKET x OUT OF y))
+        if let Some(bucket) = sample.bucket {
+            if bucket.on.is_some() {
+                return not_impl_err!(
+                    "TABLESAMPLE BUCKET with ON clause requires CLUSTERED BY table"
+                );
+            }
+            let bucket_num: u64 =
+                bucket.bucket.to_string().parse().map_err(|_| {
+                    plan_datafusion_err!("bucket number must be an integer")
+                })?;
+            let total: u64 =
+                bucket.total.to_string().parse().map_err(|_| {
+                    plan_datafusion_err!("bucket total must be an integer")
+                })?;
+
+            let fraction = bucket_num as f64 / total as f64;
+            let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan();
+            return Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
+                plan, alias,
+            ))));
+        }
+
+        // Handle quantity-based sampling
+        let Some(quantity) = sample.quantity else {
+            return plan_err!(
+                "TABLESAMPLE requires a quantity (percentage, fraction, or row count)"
+            );
+        };
+        let quantity_value_expr = context.sql_to_expr(quantity.value, input.schema())?;
+
+        match quantity.unit {
+            // TABLESAMPLE (N ROWS) - exact row limit
+            Some(TableSampleUnit::Rows) => {
+                let rows: i64 = parse_literal::<Int64Type>(&quantity_value_expr)?;
+                if rows < 0 {
+                    return plan_err!("row count must be non-negative, got {}", rows);
+                }
+                let plan = LogicalPlanBuilder::from(input)
+                    .limit(0, Some(rows as usize))?
+                    .build()?;
+                Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
+                    plan, alias,
+                ))))
+            }
+
+            // TABLESAMPLE (N PERCENT) - percentage sampling
+            Some(TableSampleUnit::Percent) => {
+                let percent: f64 = parse_literal::<Float64Type>(&quantity_value_expr)?;
+                let fraction = percent / 100.0;
+                let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan();
+                Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
+                    plan, alias,
+                ))))
+            }
+
+            // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0
+            None => {
+                let value = parse_literal::<Float64Type>(&quantity_value_expr)?;
+                if value < 0.0 {
+                    return plan_err!("sample value must be non-negative, got {}", value);
+                }
+                let plan = if value >= 1.0 {
+                    // Interpret as row limit
+                    LogicalPlanBuilder::from(input)
+                        .limit(0, Some(value as usize))?
+                        .build()?
+                } else {
+                    // Interpret as fraction
+                    TableSamplePlanNode::new(input, value, seed).into_plan()
+                };
+                Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
+                    plan, alias,
+                ))))
+            }
+        }
+    }
+}
+
+/// Custom logical plan node representing a TABLESAMPLE operation.
+///
+/// Stores sampling parameters (bounds, seed) and wraps the input plan.
+/// Gets converted to [`SampleExec`] during physical planning.
+#[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)]
+struct TableSamplePlanNode {
+    input: LogicalPlan,
+    lower_bound: HashableF64,
+    upper_bound: HashableF64,
+    seed: u64,
+}
+
+impl TableSamplePlanNode {
+    /// Create a new sampling node with the given fraction (0.0 to 1.0).
+    fn new(input: LogicalPlan, fraction: f64, seed: Option<u64>) -> Self {
+        Self {
+            input,
+            lower_bound: HashableF64(0.0),
+            upper_bound: HashableF64(fraction),
+            seed: seed.unwrap_or_else(rand::random),
+        }
+    }
+
+    /// Wrap this node in a LogicalPlan::Extension.
+    fn into_plan(self) -> LogicalPlan {
+        LogicalPlan::Extension(Extension {
+            node: Arc::new(self),
+        })
+    }
+}
+
+impl UserDefinedLogicalNodeCore for TableSamplePlanNode {
+    fn name(&self) -> &str {
+        "TableSample"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Sample: bounds=[{}, {}], seed={}",
+            self.lower_bound.0, self.upper_bound.0, self.seed
+        )
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        Ok(Self {
+            input: inputs.swap_remove(0),
+            lower_bound: self.lower_bound,
+            upper_bound: self.upper_bound,
+            seed: self.seed,
+        })
+    }
+}
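+
+// NOTE (editorial): `LogicalPlan` extension nodes must implement `Hash` and
+// `Eq`, which `f64` does not (NaN != NaN). The wrapper below compares and
+// hashes the raw bit pattern instead, e.g.:
+//
+//     assert_eq!(HashableF64(0.5), HashableF64(0.5));  // identical bits
+//     assert_ne!(HashableF64(0.0), HashableF64(-0.0)); // sign bit differs
+//
+// Bitwise equality is stricter than `==` on `f64` (where `0.0 == -0.0`),
+// which is acceptable for plan-equality purposes.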
+
+/// Wrapper for f64 that implements Hash and Eq (required for LogicalPlan).
+#[derive(Debug, Clone, Copy, PartialOrd)]
+struct HashableF64(f64);
+
+impl PartialEq for HashableF64 {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_bits() == other.0.to_bits()
+    }
+}
+
+impl Eq for HashableF64 {}
+
+impl Hash for HashableF64 {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.0.to_bits().hash(state);
+    }
+}
+
+// ============================================================================
+// Physical Planning: TableSampleQueryPlanner + TableSampleExtensionPlanner
+// ============================================================================
+
+/// Custom query planner that registers [`TableSampleExtensionPlanner`] to
+/// convert [`TableSamplePlanNode`] into [`SampleExec`].
+#[derive(Debug)]
+struct TableSampleQueryPlanner;
+
+#[async_trait]
+impl QueryPlanner for TableSampleQueryPlanner {
+    async fn create_physical_plan(
+        &self,
+        logical_plan: &LogicalPlan,
+        session_state: &SessionState,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(
+            TableSampleExtensionPlanner,
+        )]);
+        planner
+            .create_physical_plan(logical_plan, session_state)
+            .await
+    }
+}
+
+/// Extension planner that converts [`TableSamplePlanNode`] to [`SampleExec`].
+struct TableSampleExtensionPlanner;
+
+#[async_trait]
+impl ExtensionPlanner for TableSampleExtensionPlanner {
+    async fn plan_extension(
+        &self,
+        _planner: &dyn PhysicalPlanner,
+        node: &dyn UserDefinedLogicalNode,
+        _logical_inputs: &[&LogicalPlan],
+        physical_inputs: &[Arc<dyn ExecutionPlan>],
+        _session_state: &SessionState,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        let Some(sample_node) = node.as_any().downcast_ref::<TableSamplePlanNode>()
+        else {
+            return Ok(None);
+        };
+
+        let exec = SampleExec::try_new(
+            Arc::clone(&physical_inputs[0]),
+            sample_node.lower_bound.0,
+            sample_node.upper_bound.0,
+            sample_node.seed,
+        )?;
+        Ok(Some(Arc::new(exec)))
+    }
+}
+
+// ============================================================================
+// Physical Execution: SampleExec + BernoulliSampler
+// ============================================================================
+
+/// Physical execution plan that samples rows from its input using Bernoulli sampling.
+///
+/// Each row is independently selected with probability `(upper_bound - lower_bound)`
+/// and appears at most once.
+#[derive(Debug, Clone)]
+pub struct SampleExec {
+    input: Arc<dyn ExecutionPlan>,
+    lower_bound: f64,
+    upper_bound: f64,
+    seed: u64,
+    metrics: ExecutionPlanMetricsSet,
+    cache: Arc<PlanProperties>,
+}
+
+impl SampleExec {
+    /// Create a new SampleExec with Bernoulli sampling (without replacement).
+    ///
+    /// # Arguments
+    /// * `input` - The input execution plan
+    /// * `lower_bound` - Lower bound of sampling range (typically 0.0)
+    /// * `upper_bound` - Upper bound of sampling range (0.0 to 1.0)
+    /// * `seed` - Random seed for reproducible sampling
+    pub fn try_new(
+        input: Arc<dyn ExecutionPlan>,
+        lower_bound: f64,
+        upper_bound: f64,
+        seed: u64,
+    ) -> Result<Self> {
+        if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound {
+            return internal_err!(
+                "Sampling bounds must satisfy 0.0 <= lower <= upper <= 1.0, got [{}, {}]",
+                lower_bound,
+                upper_bound
+            );
+        }
+
+        let cache = PlanProperties::new(
+            EquivalenceProperties::new(input.schema()),
+            input.properties().partitioning.clone(),
+            input.properties().emission_type,
+            input.properties().boundedness,
+        );
+
+        Ok(Self {
+            input,
+            lower_bound,
+            upper_bound,
+            seed,
+            metrics: ExecutionPlanMetricsSet::new(),
+            cache: Arc::new(cache),
+        })
+    }
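+
+    // NOTE (editorial): `create_sampler` (below) derives a per-partition RNG
+    // seed via `seed.wrapping_add(partition as u64)`: partition 0 uses `seed`,
+    // partition 1 uses `seed + 1`, and so on. Results stay reproducible for a
+    // fixed seed, while partitions avoid sampling identical row positions in
+    // lockstep.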
+    /// Create a sampler for the given partition.
+    fn create_sampler(&self, partition: usize) -> BernoulliSampler {
+        let seed = self.seed.wrapping_add(partition as u64);
+        BernoulliSampler::new(self.lower_bound, self.upper_bound, seed)
+    }
+}
+
+impl DisplayAs for SampleExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "SampleExec: bounds=[{}, {}], seed={}",
+            self.lower_bound, self.upper_bound, self.seed
+        )
+    }
+}
+
+impl ExecutionPlan for SampleExec {
+    fn name(&self) -> &'static str {
+        "SampleExec"
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        // Sampling preserves row order (rows are filtered, not reordered)
+        vec![true]
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        mut children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(Self::try_new(
+            children.swap_remove(0),
+            self.lower_bound,
+            self.upper_bound,
+            self.seed,
+        )?))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(SampleStream {
+            input: self.input.execute(partition, context)?,
+            sampler: self.create_sampler(partition),
+            metrics: BaselineMetrics::new(&self.metrics, partition),
+        }))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
+        let mut stats = Arc::unwrap_or_clone(self.input.partition_statistics(partition)?);
+        let ratio = self.upper_bound - self.lower_bound;
+
+        // Scale statistics by sampling ratio (inexact due to randomness)
+        stats.num_rows = stats
+            .num_rows
+            .map(|n| (n as f64 * ratio) as usize)
+            .to_inexact();
+        stats.total_byte_size = stats
+            .total_byte_size
+            .map(|n| (n as f64 * ratio) as usize)
+            .to_inexact();
+
+        Ok(Arc::new(stats))
+    }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
+    }
+}
+
+/// Bernoulli sampler: includes each row with probability `(upper - lower)`.
+/// This is sampling **without replacement** - each row appears at most once.
+struct BernoulliSampler {
+    lower_bound: f64,
+    upper_bound: f64,
+    rng: StdRng,
+}
+
+impl BernoulliSampler {
+    fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self {
+        Self {
+            lower_bound,
+            upper_bound,
+            rng: StdRng::seed_from_u64(seed),
+        }
+    }
+
+    fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+        let range = self.upper_bound - self.lower_bound;
+        if range <= 0.0 {
+            return Ok(RecordBatch::new_empty(batch.schema()));
+        }
+
+        // Select rows where random value falls in [lower, upper)
+        let indices: Vec<u32> = (0..batch.num_rows())
+            .filter(|_| {
+                let r: f64 = self.rng.random();
+                r >= self.lower_bound && r < self.upper_bound
+            })
+            .map(|i| i as u32)
+            .collect();
+
+        if indices.is_empty() {
+            return Ok(RecordBatch::new_empty(batch.schema()));
+        }
+
+        compute::take_record_batch(batch, &UInt32Array::from(indices))
+            .map_err(DataFusionError::from)
+    }
+}
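+
+// NOTE (editorial): with Bernoulli sampling the output size is itself random.
+// For n input rows and fraction p = upper - lower, the selected row count
+// follows Binomial(n, p): the expectation is n * p, but a single run can keep
+// more or fewer rows (Example 2 keeps 4 of 10 rows at 30 PERCENT). Use
+// `TABLESAMPLE (k ROWS)` when an exact row count is required.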
+
+/// Stream adapter that applies sampling to each batch.
+struct SampleStream {
+    input: SendableRecordBatchStream,
+    sampler: BernoulliSampler,
+    metrics: BaselineMetrics,
+}
+
+impl Stream for SampleStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        match ready!(self.input.poll_next_unpin(cx)) {
+            Some(Ok(batch)) => {
+                let elapsed = self.metrics.elapsed_compute().clone();
+                let _timer = elapsed.timer();
+                let result = self.sampler.sample(&batch);
+                Poll::Ready(Some(result.record_output(&self.metrics)))
+            }
+            Some(Err(e)) => Poll::Ready(Some(Err(e))),
+            None => Poll::Ready(None),
+        }
+    }
+}
+
+impl RecordBatchStream for SampleStream {
+    fn schema(&self) -> SchemaRef {
+        self.input.schema()
+    }
+}
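+
+// NOTE (editorial sketch): determinism of seeded sampling can be verified with
+// a small test along these lines (`run` is a hypothetical helper that executes
+// the SQL and collects formatted results; it is not part of this example):
+//
+//     #[tokio::test]
+//     async fn repeatable_seed_is_deterministic() {
+//         let sql = "SELECT * FROM sample_data TABLESAMPLE (0.5) REPEATABLE(456)";
+//         // The same seeded query must select the same rows on every run.
+//         assert_eq!(run(sql).await, run(sql).await);
+//     }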
diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs
deleted file mode 100644
index 20b515506f3b4..0000000000000
--- a/datafusion-examples/examples/sql_dialect.rs
+++ /dev/null
@@ -1,134 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::fmt::Display;
-
-use datafusion::error::{DataFusionError, Result};
-use datafusion::sql::{
-    parser::{CopyToSource, CopyToStatement, DFParser, DFParserBuilder, Statement},
-    sqlparser::{keywords::Keyword, tokenizer::Token},
-};
-
-/// This example demonstrates how to use the DFParser to parse a statement in a custom way
-///
-/// This technique can be used to implement a custom SQL dialect, for example.
-#[tokio::main]
-async fn main() -> Result<()> {
-    let mut my_parser =
-        MyParser::new("COPY source_table TO 'file.fasta' STORED AS FASTA")?;
-
-    let my_statement = my_parser.parse_statement()?;
-
-    match my_statement {
-        MyStatement::DFStatement(s) => println!("df: {s}"),
-        MyStatement::MyCopyTo(s) => println!("my_copy: {s}"),
-    }
-
-    Ok(())
-}
-
-/// Here we define a Parser for our new SQL dialect that wraps the existing `DFParser`
-struct MyParser<'a> {
-    df_parser: DFParser<'a>,
-}
-
-impl<'a> MyParser<'a> {
-    fn new(sql: &'a str) -> Result<Self> {
-        let df_parser = DFParserBuilder::new(sql).build()?;
-        Ok(Self { df_parser })
-    }
-
-    /// Returns true if the next token is `COPY` keyword, false otherwise
-    fn is_copy(&self) -> bool {
-        matches!(
-            self.df_parser.parser.peek_token().token,
-            Token::Word(w) if w.keyword == Keyword::COPY
-        )
-    }
-
-    /// This is the entry point to our parser -- it handles `COPY` statements specially
-    /// but otherwise delegates to the existing DataFusion parser.
-    pub fn parse_statement(&mut self) -> Result<MyStatement> {
-        if self.is_copy() {
-            self.df_parser.parser.next_token(); // COPY
-            let df_statement = self.df_parser.parse_copy()?;
-
-            if let Statement::CopyTo(s) = df_statement {
-                Ok(MyStatement::from(s))
-            } else {
-                Ok(MyStatement::DFStatement(Box::from(df_statement)))
-            }
-        } else {
-            let df_statement = self.df_parser.parse_statement()?;
-            Ok(MyStatement::from(df_statement))
-        }
-    }
-}
-
-enum MyStatement {
-    DFStatement(Box<Statement>),
-    MyCopyTo(MyCopyToStatement),
-}
-
-impl Display for MyStatement {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            MyStatement::DFStatement(s) => write!(f, "{s}"),
-            MyStatement::MyCopyTo(s) => write!(f, "{s}"),
-        }
-    }
-}
-
-impl From<Statement> for MyStatement {
-    fn from(s: Statement) -> Self {
-        Self::DFStatement(Box::from(s))
-    }
-}
-
-impl From<CopyToStatement> for MyStatement {
-    fn from(s: CopyToStatement) -> Self {
-        if s.stored_as == Some("FASTA".to_string()) {
-            Self::MyCopyTo(MyCopyToStatement::from(s))
-        } else {
-            Self::DFStatement(Box::from(Statement::CopyTo(s)))
-        }
-    }
-}
-
-struct MyCopyToStatement {
-    pub source: CopyToSource,
-    pub target: String,
-}
-
-impl From<CopyToStatement> for MyCopyToStatement {
-    fn from(s: CopyToStatement) -> Self {
-        Self {
-            source: s.source,
-            target: s.target,
-        }
-    }
-}
-
-impl Display for MyCopyToStatement {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "COPY {} TO '{}' STORED AS FASTA",
-            self.source, self.target
-        )
-    }
-}
diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_ops/analysis.rs
similarity index 98%
rename from datafusion-examples/examples/sql_analysis.rs
rename to datafusion-examples/examples/sql_ops/analysis.rs
index 4ff669faf1d0c..4243a2927865b 100644
--- a/datafusion-examples/examples/sql_analysis.rs
+++ b/datafusion-examples/examples/sql_ops/analysis.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+//!
 //! This example shows how to use the structures that DataFusion provides to perform
 //! Analysis on SQL queries and their plans.
 //!
@@ -23,8 +25,8 @@
 
 use std::sync::Arc;
 
-use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion::common::Result;
+use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion::logical_expr::LogicalPlan;
 use datafusion::{
     datasource::MemTable,
@@ -32,141 +34,9 @@
 };
 use test_utils::tpcds::tpcds_schemas;
 
-/// Counts the total number of joins in a plan
-fn total_join_count(plan: &LogicalPlan) -> usize {
-    let mut total = 0;
-
-    // We can use the TreeNode API to walk over a LogicalPlan.
-    plan.apply(|node| {
-        // if we encounter a join we update the running count
-        if matches!(node, LogicalPlan::Join(_)) {
-            total += 1;
-        }
-        Ok(TreeNodeRecursion::Continue)
-    })
-    .unwrap();
-
-    total
-}
-
-/// Counts the total number of joins in a plan and collects every join tree in
-/// the plan with their respective join count.
-///
-/// Join Tree Definition: the largest subtree consisting entirely of joins
-///
-/// For example, this plan:
-///
-/// ```text
-///      JOIN
-///     /    \
-///    A     JOIN
-///          /   \
-///         B     C
-/// ```
-///
-/// has a single join tree `(A-B-C)` which will result in `(2, [2])`
-///
-/// This plan:
-///
-/// ```text
-///      JOIN
-///     /    \
-///    A     GROUP
-///            |
-///           JOIN
-///          /   \
-///         B     C
-/// ```
-///
-/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])`
-fn count_trees(plan: &LogicalPlan) -> (usize, Vec<usize>) {
-    // this works the same way as `total_count`, but now when we encounter a Join
-    // we try to collect it's entire tree
-    let mut to_visit = vec![plan];
-    let mut total = 0;
-    let mut groups = vec![];
-
-    while let Some(node) = to_visit.pop() {
-        // if we encounter a join, we know were at the root of the tree
-        // count this tree and recurse on it's inputs
-        if matches!(node, LogicalPlan::Join(_)) {
-            let (group_count, inputs) = count_tree(node);
-            total += group_count;
-            groups.push(group_count);
-            to_visit.extend(inputs);
-        } else {
-            to_visit.extend(node.inputs());
-        }
-    }
-
-    (total, groups)
-}
-
-/// Count the entire join tree and return its inputs using TreeNode API
-///
-/// For example, if this function receives following plan:
-///
-/// ```text
-///      JOIN
-///     /    \
-///    A     GROUP
-///            |
-///           JOIN
-///          /   \
-///         B     C
-/// ```
-///
-/// It will return `(1, [A, GROUP])`
-fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) {
-    let mut inputs = Vec::new();
-    let mut total = 0;
-
-    join.apply(|node| {
-        // Some extra knowledge:
-        //
-        // optimized plans have their projections pushed down as far as
-        // possible, which sometimes results in a projection going in between 2
-        // subsequent joins giving the illusion these joins are not "related",
-        // when in fact they are.
-        //
-        // This plan:
-        //      JOIN
-        //     /    \
-        //    A     PROJECTION
-        //            |
-        //           JOIN
-        //          /   \
-        //         B     C
-        //
-        // is the same as:
-        //
-        //      JOIN
-        //     /    \
-        //    A     JOIN
-        //          /   \
-        //         B     C
-        // we can continue the recursion in this case
-        if let LogicalPlan::Projection(_) = node {
-            return Ok(TreeNodeRecursion::Continue);
-        }
-
-        // any join we count
-        if matches!(node, LogicalPlan::Join(_)) {
-            total += 1;
-            Ok(TreeNodeRecursion::Continue)
-        } else {
-            inputs.push(node);
-            // skip children of input node
-            Ok(TreeNodeRecursion::Jump)
-        }
-    })
-    .unwrap();
-
-    (total, inputs)
-}
-
-#[tokio::main]
-async fn main() -> Result<()> {
+/// Demonstrates how to analyze a SQL query by counting JOINs and identifying
+/// join-trees using DataFusion’s `LogicalPlan` and `TreeNode` API.
+pub async fn analysis() -> Result<()> {
     // To show how we can count the joins in a sql query we'll be using query 88
     // from the TPC-DS benchmark.
     //
@@ -310,3 +180,136 @@ from
 
     Ok(())
 }
+
+/// Counts the total number of joins in a plan
+fn total_join_count(plan: &LogicalPlan) -> usize {
+    let mut total = 0;
+
+    // We can use the TreeNode API to walk over a LogicalPlan.
+    plan.apply(|node| {
+        // if we encounter a join we update the running count
+        if matches!(node, LogicalPlan::Join(_)) {
+            total += 1;
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })
+    .unwrap();
+
+    total
+}
+
+/// Counts the total number of joins in a plan and collects every join tree in
+/// the plan with their respective join count.
+///
+/// Join Tree Definition: the largest subtree consisting entirely of joins
+///
+/// For example, this plan:
+///
+/// ```text
+///      JOIN
+///     /    \
+///    A     JOIN
+///          /   \
+///         B     C
+/// ```
+///
+/// has a single join tree `(A-B-C)` which will result in `(2, [2])`
+///
+/// This plan:
+///
+/// ```text
+///      JOIN
+///     /    \
+///    A     GROUP
+///            |
+///           JOIN
+///          /   \
+///         B     C
+/// ```
+///
+/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])`
+fn count_trees(plan: &LogicalPlan) -> (usize, Vec<usize>) {
+    // this works the same way as `total_join_count`, but now when we encounter a Join
+    // we try to collect its entire tree
+    let mut to_visit = vec![plan];
+    let mut total = 0;
+    let mut groups = vec![];
+
+    while let Some(node) = to_visit.pop() {
+        // if we encounter a join, we know we're at the root of the tree;
+        // count this tree and recurse on its inputs
+        if matches!(node, LogicalPlan::Join(_)) {
+            let (group_count, inputs) = count_tree(node);
+            total += group_count;
+            groups.push(group_count);
+            to_visit.extend(inputs);
+        } else {
+            to_visit.extend(node.inputs());
+        }
+    }
+
+    (total, groups)
+}
+
+/// Count the entire join tree and return its inputs using TreeNode API
+///
+/// For example, if this function receives the following plan:
+///
+/// ```text
+///      JOIN
+///     /    \
+///    A     GROUP
+///            |
+///           JOIN
+///          /   \
+///         B     C
+/// ```
+///
+/// It will return `(1, [A, GROUP])`
+fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) {
+    let mut inputs = Vec::new();
+    let mut total = 0;
+
+    join.apply(|node| {
+        // Some extra knowledge:
+        //
+        // optimized plans have their projections pushed down as far as
+        // possible, which sometimes results in a projection going in between 2
+        // subsequent joins giving the illusion these joins are not "related",
+        // when in fact they are.
+        //
+        // This plan:
+        //      JOIN
+        //     /    \
+        //    A     PROJECTION
+        //            |
+        //           JOIN
+        //          /   \
+        //         B     C
+        //
+        // is the same as:
+        //
+        //      JOIN
+        //     /    \
+        //    A     JOIN
+        //          /   \
+        //         B     C
+        // we can continue the recursion in this case
+        if let LogicalPlan::Projection(_) = node {
+            return Ok(TreeNodeRecursion::Continue);
+        }
+
+        // any join we count
+        if matches!(node, LogicalPlan::Join(_)) {
+            total += 1;
+            Ok(TreeNodeRecursion::Continue)
+        } else {
+            inputs.push(node);
+            // skip children of input node
+            Ok(TreeNodeRecursion::Jump)
+        }
+    })
+    .unwrap();
+
+    (total, inputs)
+}
diff --git a/datafusion-examples/examples/sql_ops/custom_sql_parser.rs b/datafusion-examples/examples/sql_ops/custom_sql_parser.rs
new file mode 100644
index 0000000000000..308a0de62a242
--- /dev/null
+++ b/datafusion-examples/examples/sql_ops/custom_sql_parser.rs
@@ -0,0 +1,420 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//!
This example demonstrates extending the DataFusion SQL parser to support +//! custom DDL statements, specifically `CREATE EXTERNAL CATALOG`. +//! +//! ### Custom Syntax +//! ```sql +//! CREATE EXTERNAL CATALOG my_catalog +//! STORED AS ICEBERG +//! LOCATION 's3://my-bucket/warehouse/' +//! OPTIONS ( +//! 'region' = 'us-west-2' +//! ); +//! ``` +//! +//! Note: For the purpose of this example, we use `local://workspace/` to +//! automatically discover and register files from the project's test data. + +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; + +use datafusion::catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, + TableProviderFactory, +}; +use datafusion::datasource::listing_table_factory::ListingTableFactory; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::SessionContext; +use datafusion::sql::{ + parser::{DFParser, DFParserBuilder, Statement}, + sqlparser::{ + ast::{ObjectName, Value}, + keywords::Keyword, + tokenizer::Token, + }, +}; +use datafusion_common::{DFSchema, TableReference, plan_datafusion_err, plan_err}; +use datafusion_expr::CreateExternalTable; +use futures::StreamExt; +use insta::assert_snapshot; +use object_store::ObjectStore; +use object_store::local::LocalFileSystem; + +/// Entry point for the example. +pub async fn custom_sql_parser() -> Result<()> { + // Use standard Parquet testing data as our "external" source. + let base_path = datafusion::common::test_util::parquet_test_data(); + let base_path = std::path::Path::new(&base_path).canonicalize()?; + + // Make the path relative to the workspace root + let workspace_root = workspace_root(); + let location = base_path + .strip_prefix(&workspace_root) + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| base_path.to_string_lossy().to_string()); + + let create_catalog_sql = format!( + "CREATE EXTERNAL CATALOG parquet_testing + STORED AS parquet + LOCATION 'local://workspace/{location}' + OPTIONS ( + 'schema_name' = 'staged_data', + 'format.pruning' = 'true' + )" + ); + + // ========================================================================= + // Part 1: Standard DataFusion parser rejects the custom DDL + // ========================================================================= + println!("=== Part 1: Standard DataFusion Parser ===\n"); + println!("Parsing: {}\n", create_catalog_sql.trim()); + + let ctx_standard = SessionContext::new(); + let err = ctx_standard + .sql(&create_catalog_sql) + .await + .expect_err("Expected the standard parser to reject CREATE EXTERNAL CATALOG (custom DDL syntax)"); + + println!("Error: {err}\n"); + assert_snapshot!(err.to_string(), @r#"SQL error: ParserError("Expected: TABLE, found: CATALOG at Line: 1, Column: 17")"#); + + // ========================================================================= + // Part 2: Custom parser handles the statement + // ========================================================================= + println!("=== Part 2: Custom Parser ===\n"); + println!("Parsing: {}\n", create_catalog_sql.trim()); + + let ctx = SessionContext::new(); + + let mut parser = CustomParser::new(&create_catalog_sql)?; + let statement = parser.parse_statement()?; + match statement { + CustomStatement::CreateExternalCatalog(stmt) => { + handle_create_external_catalog(&ctx, stmt).await?; + } + CustomStatement::DFStatement(_) => { + panic!("Expected CreateExternalCatalog statement"); + } + } + + // Query a table from the registered catalog + let query_sql = "SELECT id, 
bool_col, tinyint_col FROM parquet_testing.staged_data.alltypes_plain LIMIT 5";
+    println!("Executing: {query_sql}\n");
+
+    let results = execute_sql(&ctx, query_sql).await?;
+    println!("{results}");
+    assert_snapshot!(results, @r"
+    +----+----------+-------------+
+    | id | bool_col | tinyint_col |
+    +----+----------+-------------+
+    | 4  | true     | 0           |
+    | 5  | false    | 1           |
+    | 6  | true     | 0           |
+    | 7  | false    | 1           |
+    | 2  | true     | 0           |
+    +----+----------+-------------+
+    ");
+
+    Ok(())
+}
+
+/// Execute SQL and return formatted results.
+async fn execute_sql(ctx: &SessionContext, sql: &str) -> Result<String> {
+    let batches = ctx.sql(sql).await?.collect().await?;
+    Ok(arrow::util::pretty::pretty_format_batches(&batches)?.to_string())
+}
+
+/// Custom handler for the `CREATE EXTERNAL CATALOG` statement.
+async fn handle_create_external_catalog(
+    ctx: &SessionContext,
+    stmt: CreateExternalCatalog,
+) -> Result<()> {
+    let factory = ListingTableFactory::new();
+    let catalog = Arc::new(MemoryCatalogProvider::new());
+    let schema = Arc::new(MemorySchemaProvider::new());
+
+    // Extract options
+    let mut schema_name = "public".to_string();
+    let mut table_options = HashMap::new();
+
+    for (k, v) in stmt.options {
+        let val_str = match v {
+            Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s) => {
+                s.to_string()
+            }
+            Value::Number(ref n, _) => n.to_string(),
+            Value::Boolean(b) => b.to_string(),
+            _ => v.to_string(),
+        };
+
+        if k == "schema_name" {
+            schema_name = val_str;
+        } else {
+            table_options.insert(k, val_str);
+        }
+    }
+
+    println!("  Target Catalog: {}", stmt.name);
+    println!("  Data Location: {}", stmt.location);
+    println!("  Resolved Schema: {schema_name}");
+
+    // Register a local object store rooted at the workspace root.
+    // We use a specific authority 'workspace' to ensure consistent resolution.
+    let store = Arc::new(LocalFileSystem::new_with_prefix(workspace_root())?);
+    let store_url = url::Url::parse("local://workspace").unwrap();
+    ctx.register_object_store(&store_url, Arc::clone(&store) as _);
+
+    let target_ext = format!(".{}", stmt.catalog_type.to_lowercase());
+
+    // For 'local://workspace/parquet-testing/data', the path is 'parquet-testing/data'.
+    let path_str = stmt
+        .location
+        .strip_prefix("local://workspace/")
+        .unwrap_or(&stmt.location);
+    let prefix = object_store::path::Path::from(path_str);
+
+    // Discover data files using the ObjectStore API
+    let mut table_count = 0;
+    let mut list_stream = store.list(Some(&prefix));
+
+    while let Some(meta) = list_stream.next().await {
+        let meta = meta?;
+        let path = &meta.location;
+
+        if path.as_ref().ends_with(&target_ext) {
+            let name = std::path::Path::new(path.as_ref())
+                .file_stem()
+                .unwrap()
+                .to_string_lossy()
+                .to_string();
+
+            let table_url = format!("local://workspace/{path}");
+
+            let cmd = CreateExternalTable::builder(
+                TableReference::bare(name.clone()),
+                table_url,
+                stmt.catalog_type.clone(),
+                Arc::new(DFSchema::empty()),
+            )
+            .with_options(table_options.clone())
+            .build();
+
+            match factory.create(&ctx.state(), &cmd).await {
+                Ok(table) => {
+                    schema.register_table(name, table)?;
+                    table_count += 1;
+                }
+                Err(e) => {
+                    eprintln!("Failed to create table {name}: {e}");
+                }
+            }
+        }
+    }
+    println!("  Registered {table_count} tables into schema: {schema_name}");
+
+    catalog.register_schema(&schema_name, schema)?;
+    ctx.register_catalog(stmt.name.to_string(), catalog);
+
+    Ok(())
+}
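+
+// NOTE (editorial): `local://workspace` above is an arbitrary scheme/authority
+// pair chosen for this example; any URL works as an object store key as long
+// as registration and table paths agree (`root` below is a stand-in path):
+//
+//     let store = Arc::new(LocalFileSystem::new_with_prefix(root)?);
+//     let url = url::Url::parse("local://workspace").unwrap();
+//     ctx.register_object_store(&url, store);
+//     // "local://workspace/some/dir/file.parquet" now resolves through it.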
+
+/// Possible statements returned by our custom parser.
+#[derive(Debug, Clone)]
+pub enum CustomStatement {
+    /// Standard DataFusion statement
+    DFStatement(Box<Statement>),
+    /// Custom `CREATE EXTERNAL CATALOG` statement
+    CreateExternalCatalog(CreateExternalCatalog),
+}
+
+/// Data structure for `CREATE EXTERNAL CATALOG`.
+#[derive(Debug, Clone)]
+pub struct CreateExternalCatalog {
+    pub name: ObjectName,
+    pub catalog_type: String,
+    pub location: String,
+    pub options: Vec<(String, Value)>,
+}
+
+impl Display for CustomStatement {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::DFStatement(s) => write!(f, "{s}"),
+            Self::CreateExternalCatalog(s) => write!(f, "{s}"),
+        }
+    }
+}
+
+impl Display for CreateExternalCatalog {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "CREATE EXTERNAL CATALOG {} STORED AS {} LOCATION '{}'",
+            self.name, self.catalog_type, self.location
+        )?;
+        if !self.options.is_empty() {
+            write!(f, " OPTIONS (")?;
+            for (i, (k, v)) in self.options.iter().enumerate() {
+                if i > 0 {
+                    write!(f, ", ")?;
+                }
+                write!(f, "'{k}' = '{v}'")?;
+            }
+            write!(f, ")")?;
+        }
+        Ok(())
+    }
+}
+
+/// A parser that extends `DFParser` with custom syntax.
+struct CustomParser<'a> {
+    df_parser: DFParser<'a>,
+}
+
+impl<'a> CustomParser<'a> {
+    fn new(sql: &'a str) -> Result<Self> {
+        Ok(Self {
+            df_parser: DFParserBuilder::new(sql).build()?,
+        })
+    }
+
+    pub fn parse_statement(&mut self) -> Result<CustomStatement> {
+        if self.is_create_external_catalog() {
+            return self.parse_create_external_catalog();
+        }
+        Ok(CustomStatement::DFStatement(Box::new(
+            self.df_parser.parse_statement()?,
+        )))
+    }
+
+    fn is_create_external_catalog(&self) -> bool {
+        let t1 = &self.df_parser.parser.peek_nth_token(0).token;
+        let t2 = &self.df_parser.parser.peek_nth_token(1).token;
+        let t3 = &self.df_parser.parser.peek_nth_token(2).token;
+
+        matches!(t1, Token::Word(w) if w.keyword == Keyword::CREATE)
+            && matches!(t2, Token::Word(w) if w.keyword == Keyword::EXTERNAL)
+            && matches!(t3, Token::Word(w) if w.value.to_uppercase() == "CATALOG")
+    }
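+
+    // NOTE (editorial): `peek_nth_token` is non-consuming, so the check above
+    // inspects three tokens of lookahead while leaving the token stream
+    // untouched if the statement turns out to be ordinary SQL; only once the
+    // prefix matches does `parse_create_external_catalog` (below) call
+    // `next_token` to actually consume `CREATE EXTERNAL CATALOG`.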
+
+    fn parse_create_external_catalog(&mut self) -> Result<CustomStatement> {
+        // Consume prefix tokens: CREATE EXTERNAL CATALOG
+        for _ in 0..3 {
+            self.df_parser.parser.next_token();
+        }
+
+        let name = self
+            .df_parser
+            .parser
+            .parse_object_name(false)
+            .map_err(|e| DataFusionError::External(Box::new(e)))?;
+
+        let mut catalog_type = None;
+        let mut location = None;
+        let mut options = vec![];
+
+        while let Some(keyword) = self.df_parser.parser.parse_one_of_keywords(&[
+            Keyword::STORED,
+            Keyword::LOCATION,
+            Keyword::OPTIONS,
+        ]) {
+            match keyword {
+                Keyword::STORED => {
+                    if catalog_type.is_some() {
+                        return plan_err!("Duplicate STORED AS");
+                    }
+                    self.df_parser
+                        .parser
+                        .expect_keyword(Keyword::AS)
+                        .map_err(|e| DataFusionError::External(Box::new(e)))?;
+                    catalog_type = Some(
+                        self.df_parser
+                            .parser
+                            .parse_identifier()
+                            .map_err(|e| DataFusionError::External(Box::new(e)))?
+                            .value,
+                    );
+                }
+                Keyword::LOCATION => {
+                    if location.is_some() {
+                        return plan_err!("Duplicate LOCATION");
+                    }
+                    location = Some(
+                        self.df_parser
+                            .parser
+                            .parse_literal_string()
+                            .map_err(|e| DataFusionError::External(Box::new(e)))?,
+                    );
+                }
+                Keyword::OPTIONS => {
+                    if !options.is_empty() {
+                        return plan_err!("Duplicate OPTIONS");
+                    }
+                    options = self.parse_value_options()?;
+                }
+                _ => unreachable!(),
+            }
+        }
+
+        Ok(CustomStatement::CreateExternalCatalog(
+            CreateExternalCatalog {
+                name,
+                catalog_type: catalog_type
+                    .ok_or_else(|| plan_datafusion_err!("Missing STORED AS"))?,
+                location: location
+                    .ok_or_else(|| plan_datafusion_err!("Missing LOCATION"))?,
+                options,
+            },
+        ))
+    }
+
+    /// Parse options in the form: (key [=] value, key [=] value, ...)
+    fn parse_value_options(&mut self) -> Result<Vec<(String, Value)>> {
+        let mut options = vec![];
+        self.df_parser
+            .parser
+            .expect_token(&Token::LParen)
+            .map_err(|e| DataFusionError::External(Box::new(e)))?;
+
+        loop {
+            let key = self.df_parser.parse_option_key()?;
+            // Support optional '=' between key and value
+            let _ = self.df_parser.parser.consume_token(&Token::Eq);
+            let value = self.df_parser.parse_option_value()?;
+            options.push((key, value));
+
+            let comma = self.df_parser.parser.consume_token(&Token::Comma);
+            if self.df_parser.parser.consume_token(&Token::RParen) {
+                break;
+            } else if !comma {
+                return plan_err!("Expected ',' or ')' in OPTIONS");
+            }
+        }
+        Ok(options)
+    }
+}
+
+/// Returns the workspace root directory (parent of datafusion-examples).
+fn workspace_root() -> std::path::PathBuf {
+    std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
+        .parent()
+        .expect("CARGO_MANIFEST_DIR should have a parent")
+        .to_path_buf()
+}
diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_ops/frontend.rs
similarity index 94%
rename from datafusion-examples/examples/sql_frontend.rs
rename to datafusion-examples/examples/sql_ops/frontend.rs
index 1fc9ce24ecbb5..b34c720a78198 100644
--- a/datafusion-examples/examples/sql_frontend.rs
+++ b/datafusion-examples/examples/sql_ops/frontend.rs
@@ -15,13 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! See `main.rs` for how to run it.
+
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::common::{plan_err, TableReference};
+use datafusion::common::{TableReference, plan_err};
 use datafusion::config::ConfigOptions;
 use datafusion::error::Result;
 use datafusion::logical_expr::{
-    AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource,
-    WindowUDF,
+    AggregateUDF, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF,
+    TableProviderFilterPushDown, TableSource, WindowUDF,
 };
 use datafusion::optimizer::{
     Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule,
@@ -29,7 +31,6 @@
 use datafusion::sql::planner::{ContextProvider, SqlToRel};
 use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
 use datafusion::sql::sqlparser::parser::Parser;
-use std::any::Any;
 use std::sync::Arc;
 
 /// This example shows how to use DataFusion's SQL planner to parse SQL text and
@@ -44,7 +45,7 @@
 ///
 /// In this example, we demonstrate how to use the lower level APIs directly,
 /// which only requires the `datafusion-sql` dependency.
-pub fn main() -> Result<()> {
+pub fn frontend() -> Result<()> {
     // First, we parse the SQL string. Note that we use the DataFusion
     // Parser, which wraps the `sqlparser-rs` SQL parser and adds DataFusion
     // specific syntax such as `CREATE EXTERNAL TABLE`
@@ -153,6 +154,10 @@ impl ContextProvider for MyContextProvider {
         None
     }
 
+    fn get_higher_order_meta(&self, _name: &str) -> Option<Arc<HigherOrderUDF>> {
+        None
+    }
+
     fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
         None
     }
@@ -173,6 +178,10 @@
         Vec::new()
     }
 
+    fn higher_order_function_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
     fn udaf_names(&self) -> Vec<String> {
         Vec::new()
     }
@@ -188,10 +197,6 @@ struct MyTableSource {
 }
 
 impl TableSource for MyTableSource {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }
diff --git a/datafusion-examples/examples/sql_ops/main.rs b/datafusion-examples/examples/sql_ops/main.rs
new file mode 100644
index 0000000000000..ce7be8fa2bada
--- /dev/null
+++ b/datafusion-examples/examples/sql_ops/main.rs
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # SQL Examples
+//!
+//! These examples demonstrate SQL operations in DataFusion.
+//!
+//! ## Usage
+//! ```bash
+//! cargo run --example sql_ops -- [all|analysis|custom_sql_parser|frontend|query]
+//! ```
+//!
+//! Each subcommand runs a corresponding example:
+//! - `all` — run all examples included in this module
+//!
+//! - `analysis`
+//!   (file: analysis.rs, desc: Analyze SQL queries)
+//!
+//! - `custom_sql_parser`
+//!   (file: custom_sql_parser.rs, desc: Implement a custom SQL parser to extend DataFusion)
+//!
+//! - `frontend`
+//!   (file: frontend.rs, desc: Build LogicalPlans from SQL)
+//!
+//! - `query`
+//!   (file: query.rs, desc: Query data using SQL)
+
+mod analysis;
+mod custom_sql_parser;
+mod frontend;
+mod query;
+
+use datafusion::error::{DataFusionError, Result};
+use strum::{IntoEnumIterator, VariantNames};
+use strum_macros::{Display, EnumIter, EnumString, VariantNames};
+
+#[derive(EnumIter, EnumString, Display, VariantNames)]
+#[strum(serialize_all = "snake_case")]
+enum ExampleKind {
+    All,
+    Analysis,
+    CustomSqlParser,
+    Frontend,
+    Query,
+}
+
+impl ExampleKind {
+    const EXAMPLE_NAME: &str = "sql_ops";
+
+    fn runnable() -> impl Iterator<Item = Self> {
+        ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All))
+    }
+
+    async fn run(&self) -> Result<()> {
+        match self {
+            ExampleKind::All => {
+                for example in ExampleKind::runnable() {
+                    println!("Running example: {example}");
+                    Box::pin(example.run()).await?;
+                }
+            }
+            ExampleKind::Analysis => analysis::analysis().await?,
+            ExampleKind::CustomSqlParser => {
+                custom_sql_parser::custom_sql_parser().await?
+ } + ExampleKind::Frontend => frontend::frontend()?, + ExampleKind::Query => query::query().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/sql_query.rs b/datafusion-examples/examples/sql_ops/query.rs similarity index 66% rename from datafusion-examples/examples/sql_query.rs rename to datafusion-examples/examples/sql_ops/query.rs index 0ac203cfb7e74..60b47c36b9ae2 100644 --- a/datafusion-examples/examples/sql_query.rs +++ b/datafusion-examples/examples/sql_ops/query.rs @@ -15,26 +15,27 @@ // specific language governing permissions and limitations // under the License. -use datafusion::arrow::array::{UInt64Array, UInt8Array}; +//! See `main.rs` for how to run it. + +use std::sync::Arc; + +use datafusion::arrow::array::{UInt8Array, UInt64Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::MemTable; use datafusion::common::{assert_batches_eq, exec_datafusion_err}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; -use datafusion::datasource::MemTable; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use object_store::local::LocalFileSystem; -use std::path::Path; -use std::sync::Arc; /// Examples of various ways to execute queries using SQL /// /// [`query_memtable`]: a simple query against a [`MemTable`] /// [`query_parquet`]: a simple query against a directory with multiple Parquet files -/// -#[tokio::main] -async fn main() -> Result<()> { +pub async fn query() -> Result<()> { query_memtable().await?; query_parquet().await?; Ok(()) @@ -113,32 +114,33 @@ async fn query_parquet() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - let test_data = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - // This is a workaround for this example since `test_data` contains - // many different parquet different files, - // in practice use FileType::PARQUET.get_ext(). - .with_file_extension("alltypes_plain.parquet"); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); + + let table_path = parquet_temp.file_uri()?; // First example were we use an absolute path, which requires no additional setup. 
ctx.register_listing_table( "my_table", - &format!("file://{test_data}/"), + &table_path, listing_options.clone(), None, None, ) - .await - .unwrap(); + .await?; // execute the query let df = ctx .sql( "SELECT * \ FROM my_table \ + ORDER BY speed \ LIMIT 1", ) .await?; @@ -147,20 +149,22 @@ async fn query_parquet() -> Result<()> { let results = df.collect().await?; assert_batches_eq!( [ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "+-----+-------+---------------------+", + "| car | speed | time |", + "+-----+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "+-----+-------+---------------------+", ], - &results); + &results + ); - // Second example were we temporarily move into the test data's parent directory and - // simulate a relative path, this requires registering an ObjectStore. + // Second example where we change the current working directory and explicitly + // register a local filesystem object store. This demonstrates how listing tables + // resolve paths via an ObjectStore, even when using filesystem-backed data. let cur_dir = std::env::current_dir()?; - - let test_data_path = Path::new(&test_data); - let test_data_path_parent = test_data_path + let test_data_path_parent = parquet_temp + .tmp_dir + .path() .parent() .ok_or(exec_datafusion_err!("test_data path needs a parent"))?; @@ -168,15 +172,15 @@ async fn query_parquet() -> Result<()> { let local_fs = Arc::new(LocalFileSystem::default()); - let u = url::Url::parse("file://./") + let url = url::Url::parse("file://./") .map_err(|e| DataFusionError::External(Box::new(e)))?; - ctx.register_object_store(&u, local_fs); + ctx.register_object_store(&url, local_fs); // Register a listing table - this will use all files in the directory as data sources // for the query ctx.register_listing_table( "relative_table", - "./data", + parquet_temp.path_str()?, listing_options.clone(), None, None, @@ -188,6 +192,7 @@ async fn query_parquet() -> Result<()> { .sql( "SELECT * \ FROM relative_table \ + ORDER BY speed \ LIMIT 1", ) .await?; @@ -196,13 +201,14 @@ async fn query_parquet() -> Result<()> { let results = df.collect().await?; assert_batches_eq!( [ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + 
"+-----+-------+---------------------+", + "| car | speed | time |", + "+-----+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "+-----+-------+---------------------+", ], - &results); + &results + ); // Reset the current directory std::env::set_current_dir(cur_dir)?; diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/udf/advanced_udaf.rs similarity index 96% rename from datafusion-examples/examples/advanced_udaf.rs rename to datafusion-examples/examples/udf/advanced_udaf.rs index 89f0a470e32e4..f1651dbf28913 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/udf/advanced_udaf.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{Field, Schema}; use datafusion::physical_expr::NullState; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; use arrow::array::{ ArrayRef, AsArray, Float32Array, PrimitiveArray, PrimitiveBuilder, UInt32Array, @@ -26,13 +28,13 @@ use arrow::array::{ use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_schema::FieldRef; -use datafusion::common::{cast::as_float64_array, ScalarValue}; +use datafusion::common::{ScalarValue, cast::as_float64_array}; use datafusion::error::Result; use datafusion::logical_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, expr::AggregateFunction, function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, - simplify::SimplifyInfo, - Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, + simplify::SimplifyContext, }; use datafusion::prelude::*; @@ -62,11 +64,6 @@ impl GeoMeanUdaf { } impl AggregateUDFImpl for GeoMeanUdaf { - /// We implement as_any so that we can downcast the AggregateUDFImpl trait object - fn as_any(&self) -> &dyn Any { - self - } - /// Return the name of this function fn name(&self) -> &str { "geo_mean" @@ -312,12 +309,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { let prods = emit_to.take_needed(&mut self.prods); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), prods.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), counts.len()); + } assert_eq!(counts.len(), prods.len()); // don't evaluate geometric mean with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); let iter = prods.into_iter().zip(counts).zip(nulls.iter()); @@ -335,7 +336,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { .zip(counts) .map(|(prod, count)| prod.powf(1.0 / count as f64)) .collect::>(); - PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy + PrimitiveArray::new(geo_mean.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -345,7 +346,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { // return arrays for counts and prods fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy 
@@ -382,10 +382,6 @@ impl SimplifiedGeoMeanUdaf { } impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "simplified_geo_mean" } @@ -419,7 +415,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { /// Optionally replaces a UDAF with another expression during query optimization. fn simplify(&self) -> Option<AggregateFunctionSimplification> { - let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + let simplify = |aggregate_function: AggregateFunction, _: &SimplifyContext| { // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method. // In real-world scenarios, you might create UDFs from built-in expressions. Ok(Expr::AggregateFunction(AggregateFunction::new_udf( @@ -469,8 +465,9 @@ fn create_context() -> Result<SessionContext> { Ok(ctx) } -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `GeoMeanUdaf` and `SimplifiedGeoMeanUdaf` +/// as user defined aggregate functions and invoke them via the DataFrame API and SQL +pub async fn advanced_udaf() -> Result<()> { let ctx = create_context()?; let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new());
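Although `advanced_udaf` now exposes `pub async fn advanced_udaf()` instead of its own `main`, the flow it demonstrates is unchanged: register the UDAF, then call it from SQL or the DataFrame API. A hedged, self-contained sketch of that flow, using the built-in `avg` as a stand-in for the example's `geo_mean`:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array};
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // In-memory table, mirroring the example's create_context().
    let a: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 4.0, 8.0]));
    let batch = RecordBatch::try_from_iter(vec![("a", a)])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;

    // After `ctx.register_udaf(...)`, a custom aggregate such as
    // `geo_mean(a)` is called exactly like the built-in used here.
    ctx.sql("SELECT avg(a) AS mean FROM t").await?.show().await?;
    Ok(())
}
```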
diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/udf/advanced_udf.rs similarity index 98% rename from datafusion-examples/examples/advanced_udf.rs rename to datafusion-examples/examples/udf/advanced_udf.rs index 56ae599efa11b..d3815459dba52 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/udf/advanced_udf.rs @@ -15,19 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +//! See `main.rs` for how to run it. + use std::sync::Arc; use arrow::array::{ - new_null_array, Array, ArrayRef, AsArray, Float32Array, Float64Array, + Array, ArrayRef, AsArray, Float32Array, Float64Array, new_null_array, }; use arrow::compute; use arrow::datatypes::{DataType, Float64Type}; use arrow::record_batch::RecordBatch; -use datafusion::common::{exec_err, internal_err, ScalarValue}; +use datafusion::common::{ScalarValue, exec_err, internal_err}; use datafusion::error::Result; -use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::Volatility; +use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }; @@ -64,10 +65,6 @@ impl PowUdf { } impl ScalarUDFImpl for PowUdf { /// We implement as_any so that we can downcast the ScalarUDFImpl trait object - fn as_any(&self) -> &dyn Any { - self - } - /// Return the name of this function fn name(&self) -> &str { "pow" @@ -245,10 +242,35 @@ fn maybe_pow_in_place(base: f64, exp_array: ArrayRef) -> Result<ColumnarValue> { } } +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` +fn create_context() -> Result<SessionContext> { + // define data. + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a new context. In Spark API, this corresponds to a new SparkSession + let ctx = SessionContext::new(); + + // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} + /// In this example we register `PowUdf` as a user defined function /// and invoke it via the DataFrame API and SQL -#[tokio::main] -async fn main() -> Result<()> { +pub async fn advanced_udf() -> Result<()> { let ctx = create_context()?; // create the UDF @@ -295,29 +317,3 @@ async fn main() -> Result<()> { Ok(()) } - -/// create local execution context with an in-memory table: -/// -/// ```text -/// +-----+-----+ -/// | a | b | -/// +-----+-----+ -/// | 2.1 | 1.0 | -/// | 3.1 | 2.0 | -/// | 4.1 | 3.0 | -/// | 5.1 | 4.0 | -/// +-----+-----+ -/// ``` -fn create_context() -> Result<SessionContext> { - // define data. - let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); - let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; - - // declare a new context. In Spark API, this corresponds to a new SparkSession - let ctx = SessionContext::new(); - - // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). - ctx.register_batch("t", batch)?; - Ok(ctx) -}
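The relocated `create_context` above builds the two-column table that `pow` runs against. As a hedged sketch (plain arithmetic stands in for the registered `pow` UDF, which would be invoked similarly, e.g. via `pow.call(vec![...])`), querying that table through the DataFrame API looks like:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Same data shape as the example's create_context().
    let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1f32, 3.1, 4.1, 5.1]));
    let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0f64, 2.0, 3.0, 4.0]));
    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;

    // A built-in arithmetic expression stands in for the custom UDF here.
    ctx.table("t")
        .await?
        .select(vec![(col("a") + col("b")).alias("a_plus_b")])?
        .show()
        .await?;
    Ok(())
}
```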
diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/udf/advanced_udwf.rs similarity index 93% rename from datafusion-examples/examples/advanced_udwf.rs rename to datafusion-examples/examples/udf/advanced_udwf.rs index ba4c377fd6762..2508e6cd60e59 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/udf/advanced_udwf.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +use std::sync::Arc; + use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, AsArray, Float64Array}, @@ -28,7 +32,7 @@ use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; use datafusion::logical_expr::function::{ PartitionEvaluatorArgs, WindowFunctionSimplification, WindowUDFFieldArgs, }; -use datafusion::logical_expr::simplify::SimplifyInfo; +use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{ Expr, LimitEffect, PartitionEvaluator, Signature, WindowFrame, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -36,8 +40,7 @@ use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use std::any::Any; -use std::sync::Arc; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements @@ -65,11 +68,6 @@ impl SmoothItUdf { } impl WindowUDFImpl for SmoothItUdf { - /// We implement as_any so that we can downcast the WindowUDFImpl trait object - fn as_any(&self) -> &dyn Any { - self - } - /// Return the name of this function fn name(&self) -> &str { "smooth_it" @@ -173,10 +171,6 @@ impl SimplifySmoothItUdf { } } impl WindowUDFImpl for SimplifySmoothItUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "simplify_smooth_it" } @@ -195,7 +189,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `AggregateUDF` for `Avg` /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option<WindowFunctionSimplification> { - let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { + let simplify = |window_function: WindowFunction, _: &SimplifyContext| { Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { @@ -227,17 +221,17 @@ async fn create_context() -> Result<SessionContext> { // declare a new context. In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); + let dataset = ExampleDataset::Cars; + + ctx.register_csv("cars", dataset.path_str()?, CsvReadOptions::new()) + .await?; - ctx.register_csv("cars", &csv_path, read_options).await?; Ok(ctx) } -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `SmoothItUdf` as a user defined window function +/// and invoke it via the DataFrame API and SQL +pub async fn advanced_udwf() -> Result<()> { let ctx = create_context().await?; let smooth_it = WindowUDF::from(SmoothItUdf::new()); ctx.register_udwf(smooth_it.clone()); diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/udf/async_udf.rs similarity index 81% rename from datafusion-examples/examples/async_udf.rs rename to datafusion-examples/examples/udf/async_udf.rs index b52ec68ea4422..43b82c398c5c6 100644 --- a/datafusion-examples/examples/async_udf.rs +++ b/datafusion-examples/examples/udf/async_udf.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example shows how to create and use "Async UDFs" in DataFusion. //! //! Async UDFs allow you to perform asynchronous operations, such as //! making network requests. This can be used for tasks like fetching //! data from an external API such as a LLM service or an external database.
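Looking back at the `advanced_udwf` hunk above: once `SmoothItUdf` is registered via `ctx.register_udwf`, it is invoked with ordinary `OVER` syntax. A hedged, self-contained sketch of that call pattern (the built-in `avg` stands in for `smooth_it`):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array, StringArray};
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Tiny in-memory stand-in for the cars dataset.
    let car: ArrayRef = Arc::new(StringArray::from(vec!["red", "red", "green"]));
    let speed: ArrayRef = Arc::new(Float64Array::from(vec![20.0, 20.3, 10.0]));
    let batch = RecordBatch::try_from_iter(vec![("car", car), ("speed", speed)])?;

    let ctx = SessionContext::new();
    ctx.register_batch("cars", batch)?;

    // With `smooth_it` registered, replace `avg` below with `smooth_it`.
    ctx.sql("SELECT car, speed, avg(speed) OVER (PARTITION BY car) AS smooth FROM cars")
        .await?
        .show()
        .await?;
    Ok(())
}
```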
+use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; @@ -35,11 +39,10 @@ use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use std::any::Any; -use std::sync::Arc; -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `AskLLM` as an asynchronous user defined function +/// and invoke it via the DataFrame API and SQL +pub async fn async_udf() -> Result<()> { // Use a hard coded parallelism level of 4 so the explain plan // is consistent across machines. let config = SessionConfig::new().with_target_partitions(4); @@ -90,20 +93,18 @@ async fn main() -> Result<()> { assert_batches_eq!( [ - "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", - "| logical_plan | SubqueryAlias: a |", - "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |", - "| | TableScan: animal projection=[id, name] |", - "| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |", - "| | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |", - "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", - "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | DataSourceExec: partitions=1, partition_sizes=[1] |", - "| | |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | SubqueryAlias: a |", + "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |", + "| | TableScan: animal projection=[id, name] |", + "| physical_plan | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |", + "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", + "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", + "| | DataSourceExec: partitions=1, partition_sizes=[1] |", + "| | |", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", ], &results ); @@ -159,10 +160,6 @@ impl AskLLM { /// information for the function, such as its name, signature, and return type. 
/// [async_trait] impl ScalarUDFImpl for AskLLM { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "ask_llm" } diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs new file mode 100644 index 0000000000000..89f3fd801deec --- /dev/null +++ b/datafusion-examples/examples/udf/main.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # User-Defined Functions Examples +//! +//! These examples demonstrate user-defined functions in DataFusion. +//! +//! ## Usage +//! ```bash +//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf|table_list_udtf] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `adv_udaf` +//! (file: advanced_udaf.rs, desc: Advanced User Defined Aggregate Function (UDAF)) +//! +//! - `adv_udf` +//! (file: advanced_udf.rs, desc: Advanced User Defined Scalar Function (UDF)) +//! +//! - `adv_udwf` +//! (file: advanced_udwf.rs, desc: Advanced User Defined Window Function (UDWF)) +//! +//! - `async_udf` +//! (file: async_udf.rs, desc: Asynchronous User Defined Scalar Function) +//! +//! - `udaf` +//! (file: simple_udaf.rs, desc: Simple UDAF example) +//! +//! - `udf` +//! (file: simple_udf.rs, desc: Simple UDF example) +//! +//! - `udtf` +//! (file: simple_udtf.rs, desc: Simple UDTF example) +//! +//! - `udwf` +//! (file: simple_udwf.rs, desc: Simple UDWF example) +//! +//! - `table_list_udtf` +//! 
(file: table_list_udtf.rs, desc: Session-aware UDTF table list example) + +mod advanced_udaf; +mod advanced_udf; +mod advanced_udwf; +mod async_udf; +mod simple_udaf; +mod simple_udf; +mod simple_udtf; +mod simple_udwf; +mod table_list_udtf; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + AdvUdaf, + AdvUdf, + AdvUdwf, + AsyncUdf, + Udf, + Udaf, + Udwf, + Udtf, + TableListUdtf, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "udf"; + + fn runnable() -> impl Iterator<Item = ExampleKind> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AdvUdaf => advanced_udaf::advanced_udaf().await?, + ExampleKind::AdvUdf => advanced_udf::advanced_udf().await?, + ExampleKind::AdvUdwf => advanced_udwf::advanced_udwf().await?, + ExampleKind::AsyncUdf => async_udf::async_udf().await?, + ExampleKind::Udaf => simple_udaf::simple_udaf().await?, + ExampleKind::Udf => simple_udf::simple_udf().await?, + ExampleKind::Udtf => simple_udtf::simple_udtf().await?, + ExampleKind::Udwf => simple_udwf::simple_udwf().await?, + ExampleKind::TableListUdtf => table_list_udtf::table_list_udtf().await?, + } + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/udf/simple_udaf.rs similarity index 96% rename from datafusion-examples/examples/simple_udaf.rs rename to datafusion-examples/examples/udf/simple_udaf.rs index 82bde7c034a57..42ea0054b759f 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/udf/simple_udaf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! /// In this example we will declare a single-type, single return type UDAF that computes the geometric mean. /// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean use datafusion::arrow::{ @@ -135,8 +137,9 @@ impl Accumulator for GeometricMean { } } -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `GeometricMean` +/// as a user defined aggregate function and invoke it via the DataFrame API and SQL +pub async fn simple_udaf() -> Result<()> { let ctx = create_context()?; // here is where we define the UDAF.
We also declare its signature: diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/udf/simple_udf.rs similarity index 98% rename from datafusion-examples/examples/simple_udf.rs rename to datafusion-examples/examples/udf/simple_udf.rs index 5612e0939f709..e8d6c9c8173ac 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/udf/simple_udf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::{ arrow::{ array::{ArrayRef, Float32Array, Float64Array}, @@ -57,8 +59,7 @@ fn create_context() -> Result<SessionContext> { } /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b -#[tokio::main] -async fn main() -> Result<()> { +pub async fn simple_udf() -> Result<()> { let ctx = create_context()?; // First, declare the actual implementation of the calculation diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/udf/simple_udtf.rs similarity index 85% rename from datafusion-examples/examples/simple_udtf.rs rename to datafusion-examples/examples/udf/simple_udtf.rs index b65ffb8d71748..af123ab7e5d4a 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/udf/simple_udtf.rs @@ -15,53 +15,56 @@ // specific language governing permissions and limitations // under the License. -use arrow::csv::reader::Format; +//! See `main.rs` for how to run it. + +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + use arrow::csv::ReaderBuilder; +use arrow::csv::reader::Format; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::Session; -use datafusion::catalog::TableFunctionImpl; -use datafusion::common::{plan_err, ScalarValue}; -use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl}; +use datafusion::common::{ScalarValue, plan_err}; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; -use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{Expr, TableType}; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; -use std::fs::File; -use std::io::Seek; -use std::path::Path; -use std::sync::Arc; +use datafusion_examples::utils::datasets::ExampleDataset; + // To define your own table function, you only need to do the following 3 things: // 1. Implement your own [`TableProvider`] // 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] // 3.
Register the function using [`SessionContext::register_udtf`] /// This example demonstrates how to register a TableFunction -#[tokio::main] -async fn main() -> Result<()> { +pub async fn simple_udtf() -> Result<()> { // create local execution context let ctx = SessionContext::new(); // register the table function that will be called in SQL statements by `read_csv` ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); - let testdata = datafusion::test_util::arrow_test_data(); - let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + let dataset = ExampleDataset::Cars; // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) let df = ctx - .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .sql( + format!("SELECT * FROM read_csv('{}', 1 + 1);", dataset.path_str()?).as_str(), + ) .await?; df.show().await?; // just run, return all rows let df = ctx - .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .sql(format!("SELECT * FROM read_csv('{}');", dataset.path_str()?).as_str()) .await?; df.show().await?; @@ -82,10 +85,6 @@ struct LocalCsvTable { #[async_trait] impl TableProvider for LocalCsvTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -132,9 +131,9 @@ impl TableProvider for LocalCsvTable { struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() - else { + fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> { + let exprs = args.exprs(); + let Some(Expr::Literal(ScalarValue::Utf8(Some(path)), _)) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; @@ -142,8 +141,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { .get(1) .map(|expr| { // try to simplify the expression, so 1+2 becomes 3, for example - let execution_props = ExecutionProps::new(); - let info = SimplifyContext::new(&execution_props); + let info = SimplifyContext::default(); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/udf/simple_udwf.rs similarity index 79% rename from datafusion-examples/examples/simple_udwf.rs rename to datafusion-examples/examples/udf/simple_udwf.rs index 1736ff00bd700..1842d88b9ba29 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/udf/simple_udwf.rs @@ -15,35 +15,70 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +//! See `main.rs` for how to run it. + +use std::{fs::File, io::Write, sync::Arc}; use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::{DataType, Float64Type}, }; - use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::{PartitionEvaluator, Volatility, WindowFrame}; use datafusion::prelude::*; +use tempfile::tempdir; // create local execution context with `cars.csv` registered as a table named `cars` async fn create_context() -> Result<SessionContext> { // declare a new context. In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
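The `simple_udtf` hunk above drops the explicit `ExecutionProps` in favor of `SimplifyContext::default()`. A minimal sketch of that simplification path in isolation, constant-folding `1 + 1` exactly as the example's comment describes:

```rust
use datafusion::error::Result;
use datafusion::logical_expr::simplify::SimplifyContext;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::prelude::*;

fn main() -> Result<()> {
    // One-step construction; previously this required building
    // ExecutionProps first and passing a reference to SimplifyContext::new.
    let info = SimplifyContext::default();
    let simplified = ExprSimplifier::new(info).simplify(lit(1i64) + lit(1i64))?;

    // The constant fold turns `1 + 1` into the literal `2`.
    assert_eq!(simplified, lit(2i64));
    Ok(())
}
```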
- println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); + // content from file 'datafusion/core/tests/data/cars.csv' + let csv_data = r#"car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 +"#; + let dir = tempdir()?; + let file_path = dir.path().join("cars.csv"); + { + let mut file = File::create(&file_path)?; + // write CSV data + file.write_all(csv_data.as_bytes())?; + } // scope closes the file + let file_path = file_path.to_str().unwrap(); + + ctx.register_csv("cars", file_path, CsvReadOptions::new()) + .await?; - ctx.register_csv("cars", &csv_path, read_options).await?; Ok(ctx) } /// In this example we will declare a user defined window function that computes a moving average and then run it using SQL -#[tokio::main] -async fn main() -> Result<()> { +pub async fn simple_udwf() -> Result<()> { let ctx = create_context().await?; // here is where we define the UDWF. We also declare its signature: diff --git a/datafusion-examples/examples/udf/table_list_udtf.rs b/datafusion-examples/examples/udf/table_list_udtf.rs new file mode 100644 index 0000000000000..739f8e11f07e5 --- /dev/null +++ b/datafusion-examples/examples/udf/table_list_udtf.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. 
+ +use std::sync::{Arc, LazyLock}; + +use arrow::array::{RecordBatch, StringBuilder}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::{ + catalog::{MemTable, TableFunctionArgs, TableFunctionImpl, TableProvider}, + common::Result, + execution::SessionState, + prelude::SessionContext, +}; +use datafusion_common::{DataFusionError, plan_err}; +use tokio::{runtime::Handle, task::block_in_place}; + +const FUNCTION_NAME: &str = "table_list"; + +// This example shows how to create a UDTF that depends on the session state. +// It defines a `table_list` UDTF that returns a list of tables within the provided session. + +pub async fn table_list_udtf() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_udtf(FUNCTION_NAME, Arc::new(TableListUdtf)); + + // Register different kinds of tables. + ctx.sql("create view v as select 1") + .await? + .collect() + .await?; + ctx.sql("create table t(a int)").await?.collect().await?; + + // Print results. + ctx.sql("select * from table_list()").await?.show().await?; + + Ok(()) +} + +#[derive(Debug, Default)] +struct TableListUdtf; + +static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| { + SchemaRef::new(Schema::new(vec![ + Field::new("catalog", DataType::Utf8, false), + Field::new("schema", DataType::Utf8, false), + Field::new("table", DataType::Utf8, false), + Field::new("type", DataType::Utf8, false), + ])) +}); + +impl TableFunctionImpl for TableListUdtf { + fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> { + if !args.exprs().is_empty() { + return plan_err!( + "{}: unexpected number of arguments: {}, expected: 0", + FUNCTION_NAME, + args.exprs().len() + ); + } + let state = args + .session() + .as_any() + .downcast_ref::<SessionState>() + .ok_or_else(|| { + DataFusionError::Internal("failed to downcast state".into()) + })?; + + let mut catalogs = StringBuilder::new(); + let mut schemas = StringBuilder::new(); + let mut tables = StringBuilder::new(); + let mut types = StringBuilder::new(); + + let catalog_list = state.catalog_list(); + for catalog_name in catalog_list.catalog_names() { + let Some(catalog) = catalog_list.catalog(&catalog_name) else { + continue; + }; + for schema_name in catalog.schema_names() { + let Some(schema) = catalog.schema(&schema_name) else { + continue; + }; + for table_name in schema.table_names() { + let Some(provider) = block_in_place(|| { + Handle::current().block_on(schema.table(&table_name)) + })? + else { + continue; + }; + catalogs.append_value(catalog_name.clone()); + schemas.append_value(schema_name.clone()); + tables.append_value(table_name.clone()); + types.append_value(provider.table_type().to_string()) + } + } + } + + let batch = RecordBatch::try_new( + Arc::clone(&SCHEMA), + vec![ + Arc::new(catalogs.finish()), + Arc::new(schemas.finish()), + Arc::new(tables.finish()), + Arc::new(types.finish()), + ], + )?; + + Ok(Arc::new(MemTable::try_new( + batch.schema(), + vec![vec![batch]], + )?)) + } +} diff --git a/datafusion-examples/src/bin/examples-docs.rs b/datafusion-examples/src/bin/examples-docs.rs new file mode 100644 index 0000000000000..7efcf4da15d20 --- /dev/null +++ b/datafusion-examples/src/bin/examples-docs.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generates Markdown documentation for DataFusion example groups. +//! +//! This binary scans `datafusion-examples/examples`, extracts structured +//! documentation from each group's `main.rs` file, and renders a README-style +//! Markdown document. +//! +//! By default, documentation is generated for all example groups. If a group +//! name is provided as the first CLI argument, only that group is rendered. +//! +//! ## Usage +//! +//! ```bash +//! # Generate docs for all example groups +//! cargo run --bin examples-docs +//! +//! # Generate docs for a single group +//! cargo run --bin examples-docs -- dataframe +//! ``` + +use datafusion_examples::utils::example_metadata::{ + RepoLayout, generate_examples_readme, +}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let layout = RepoLayout::detect()?; + let group = std::env::args().nth(1); + let markdown = generate_examples_readme(&layout, group.as_deref())?; + print!("{markdown}"); + Ok(()) +} diff --git a/datafusion-examples/src/lib.rs b/datafusion-examples/src/lib.rs new file mode 100644 index 0000000000000..7f334aedaafe2 --- /dev/null +++ b/datafusion-examples/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Internal utilities shared by the DataFusion examples. + +pub mod utils; diff --git a/datafusion-examples/src/utils/csv_to_parquet.rs b/datafusion-examples/src/utils/csv_to_parquet.rs new file mode 100644 index 0000000000000..1fbf2930e9043 --- /dev/null +++ b/datafusion-examples/src/utils/csv_to_parquet.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::{Path, PathBuf}; + +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use tempfile::TempDir; +use tokio::fs::create_dir_all; + +/// Temporary Parquet directory that is deleted when dropped. +#[derive(Debug)] +pub struct ParquetTemp { + pub tmp_dir: TempDir, + pub parquet_dir: PathBuf, +} + +impl ParquetTemp { + pub fn path(&self) -> &Path { + &self.parquet_dir + } + + pub fn path_str(&self) -> Result<&str> { + self.parquet_dir.to_str().ok_or_else(|| { + DataFusionError::Execution(format!( + "Parquet directory path is not valid UTF-8: {}", + self.parquet_dir.display() + )) + }) + } + + pub fn file_uri(&self) -> Result<String> { + Ok(format!("file://{}", self.path_str()?)) + } +} + +/// Helper for examples: load a CSV file and materialize it as Parquet +/// in a temporary directory. +/// +/// # Example +/// ``` +/// use std::path::PathBuf; +/// use datafusion::prelude::*; +/// use datafusion_examples::utils::write_csv_to_parquet; +/// # use datafusion::assert_batches_eq; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let ctx = SessionContext::new(); +/// let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) +/// .join("data") +/// .join("csv") +/// .join("cars.csv"); +/// let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; +/// let df = ctx.read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()).await?; +/// let rows = df +/// .sort(vec![col("speed").sort(true, true)])? +/// .limit(0, Some(5))?; +/// assert_batches_eq!( +/// &[ +/// "+-------+-------+---------------------+", +/// "| car | speed | time |", +/// "+-------+-------+---------------------+", +/// "| red | 0.0 | 1996-04-12T12:05:15 |", +/// "| red | 1.0 | 1996-04-12T12:05:14 |", +/// "| green | 2.0 | 1996-04-12T12:05:14 |", +/// "| red | 3.0 | 1996-04-12T12:05:13 |", +/// "| red | 7.0 | 1996-04-12T12:05:10 |", +/// "+-------+-------+---------------------+", +/// ], +/// &rows.collect().await?
+/// ); +/// # Ok(()) +/// # } +/// ``` +pub async fn write_csv_to_parquet( + ctx: &SessionContext, + csv_path: &Path, +) -> Result<ParquetTemp> { + if !csv_path.is_file() { + return Err(DataFusionError::Execution(format!( + "CSV file does not exist: {}", + csv_path.display() + ))); + } + + let csv_path = csv_path.to_str().ok_or_else(|| { + DataFusionError::Execution("CSV path is not valid UTF-8".to_string()) + })?; + + let csv_df = ctx.read_csv(csv_path, CsvReadOptions::default()).await?; + + let tmp_dir = TempDir::new()?; + let parquet_dir = tmp_dir.path().join("parquet_source"); + create_dir_all(&parquet_dir).await?; + + let path = parquet_dir.to_str().ok_or_else(|| { + DataFusionError::Execution("Failed processing tmp directory path".to_string()) + })?; + + csv_df + .write_parquet(path, DataFrameWriteOptions::default(), None) + .await?; + + Ok(ParquetTemp { + tmp_dir, + parquet_dir, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::path::PathBuf; + + use datafusion::assert_batches_eq; + use datafusion::prelude::*; + + #[tokio::test] + async fn test_write_csv_to_parquet_with_cars_data() -> Result<()> { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("cars.csv"); + + let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; + let df = ctx + .read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()) + .await?; + + let rows = df.sort(vec![col("speed").sort(true, true)])?; + assert_batches_eq!( + &[ + "+-------+-------+---------------------+", + "| car | speed | time |", + "+-------+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "| red | 1.0 | 1996-04-12T12:05:14 |", + "| green | 2.0 | 1996-04-12T12:05:14 |", + "| red | 3.0 | 1996-04-12T12:05:13 |", + "| red | 7.0 | 1996-04-12T12:05:10 |", + "| red | 7.1 | 1996-04-12T12:05:11 |", + "| red | 7.2 | 1996-04-12T12:05:12 |", + "| green | 8.0 | 1996-04-12T12:05:13 |", + "| green | 10.0 | 1996-04-12T12:05:03 |", + "| green | 10.3 | 1996-04-12T12:05:04 |", + "| green | 10.4 | 1996-04-12T12:05:05 |", + "| green | 10.5 | 1996-04-12T12:05:06 |", + "| green | 11.0 | 1996-04-12T12:05:07 |", + "| green | 12.0 | 1996-04-12T12:05:08 |", + "| green | 14.0 | 1996-04-12T12:05:09 |", + "| green | 15.0 | 1996-04-12T12:05:10 |", + "| green | 15.1 | 1996-04-12T12:05:11 |", + "| green | 15.2 | 1996-04-12T12:05:12 |", + "| red | 17.0 | 1996-04-12T12:05:09 |", + "| red | 18.0 | 1996-04-12T12:05:08 |", + "| red | 19.0 | 1996-04-12T12:05:07 |", + "| red | 20.0 | 1996-04-12T12:05:03 |", + "| red | 20.3 | 1996-04-12T12:05:04 |", + "| red | 21.4 | 1996-04-12T12:05:05 |", + "| red | 21.5 | 1996-04-12T12:05:06 |", + "+-------+-------+---------------------+", + ], + &rows.collect().await?
+ ); + + Ok(()) + } + + #[tokio::test] + async fn test_write_csv_to_parquet_with_regex_data() -> Result<()> { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("regex.csv"); + + let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; + let df = ctx + .read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()) + .await?; + + let rows = df.sort(vec![col("values").sort(true, true)])?; + assert_batches_eq!( + &[ + "+------------+--------------------------------------+-------------+-------+", + "| values | patterns | replacement | flags |", + "+------------+--------------------------------------+-------------+-------+", + "| 4000 | \\b4([1-9]\\d\\d|\\d[1-9]\\d|\\d\\d[1-9])\\b | xyz | |", + "| 4010 | \\b4([1-9]\\d\\d|\\d[1-9]\\d|\\d\\d[1-9])\\b | xyz | |", + "| ABC | ^(A).* | B | i |", + "| AbC | (B|D) | e | |", + "| Düsseldorf | [\\p{Letter}-]+ | München | |", + "| Köln | [a-zA-Z]ö[a-zA-Z]{2} | Koln | |", + "| aBC | ^(b|c) | d | |", + "| aBc | (b|d) | e | i |", + "| abc | ^(a) | bb\\1bb | i |", + "| Москва | [\\p{L}-]+ | Moscow | |", + "| اليوم | ^\\p{Arabic}+$ | Today | |", + "+------------+--------------------------------------+-------------+-------+", + ], + &rows.collect().await? + ); + + Ok(()) + } + + #[tokio::test] + async fn test_write_csv_to_parquet_error() { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("file-does-not-exist.csv"); + + let err = write_csv_to_parquet(&ctx, &csv_path).await.unwrap_err(); + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains("CSV file does not exist"), + "unexpected error message: {msg}" + ); + } + other => panic!("unexpected error variant: {other:?}"), + } + } +} diff --git a/datafusion/datasource-avro/src/avro_to_arrow/mod.rs b/datafusion-examples/src/utils/datasets/cars.rs similarity index 58% rename from datafusion/datasource-avro/src/avro_to_arrow/mod.rs rename to datafusion-examples/src/utils/datasets/cars.rs index c1530a4880205..2d8547c16d686 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/mod.rs +++ b/datafusion-examples/src/utils/datasets/cars.rs @@ -15,25 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! This module contains code for reading [Avro] data into `RecordBatch`es -//! -//! [Avro]: https://avro.apache.org/docs/1.2.0/ +use std::sync::Arc; -mod arrow_array_reader; -mod reader; -mod schema; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use arrow::datatypes::Schema; -pub use reader::{Reader, ReaderBuilder}; - -pub use schema::to_arrow_schema; -use std::io::Read; - -/// Read Avro schema given a reader -pub fn read_avro_schema_from_reader<R: Read>( - reader: &mut R, -) -> datafusion_common::Result<Schema> { - let avro_reader = apache_avro::Reader::new(reader)?; - let schema = avro_reader.writer_schema(); - to_arrow_schema(schema) +/// Schema for the `data/csv/cars.csv` example dataset.
+pub fn schema() -> Arc<Schema> { + Arc::new(Schema::new(vec![ + Field::new("car", DataType::Utf8, false), + Field::new("speed", DataType::Float64, false), + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ])) } diff --git a/datafusion-examples/src/utils/datasets/mod.rs b/datafusion-examples/src/utils/datasets/mod.rs new file mode 100644 index 0000000000000..1857e6af9b559 --- /dev/null +++ b/datafusion-examples/src/utils/datasets/mod.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::PathBuf; + +use arrow_schema::SchemaRef; +use datafusion::error::{DataFusionError, Result}; + +pub mod cars; +pub mod regex; + +/// Describes example datasets used across DataFusion examples. +/// +/// This enum provides a single, discoverable place to define +/// dataset-specific metadata such as file paths and schemas. +#[derive(Debug)] +pub enum ExampleDataset { + Cars, + Regex, +} + +impl ExampleDataset { + pub fn file_stem(&self) -> &'static str { + match self { + Self::Cars => "cars", + Self::Regex => "regex", + } + } + + pub fn path(&self) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join(format!("{}.csv", self.file_stem())) + } + + pub fn path_str(&self) -> Result<String> { + let path = self.path(); + path.to_str().map(String::from).ok_or_else(|| { + DataFusionError::Execution(format!( + "CSV directory path is not valid UTF-8: {}", + path.display() + )) + }) + } + + pub fn schema(&self) -> SchemaRef { + match self { + Self::Cars => cars::schema(), + Self::Regex => regex::schema(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::datatypes::{DataType, TimeUnit}; + + #[test] + fn example_dataset_file_stem() { + assert_eq!(ExampleDataset::Cars.file_stem(), "cars"); + assert_eq!(ExampleDataset::Regex.file_stem(), "regex"); + } + + #[test] + fn example_dataset_path_points_to_csv() { + let path = ExampleDataset::Cars.path(); + assert!(path.ends_with("data/csv/cars.csv")); + + let path = ExampleDataset::Regex.path(); + assert!(path.ends_with("data/csv/regex.csv")); + } + + #[test] + fn example_dataset_path_str_is_valid_utf8() { + let path = ExampleDataset::Cars.path_str().unwrap(); + assert!(path.ends_with("cars.csv")); + + let path = ExampleDataset::Regex.path_str().unwrap(); + assert!(path.ends_with("regex.csv")); + } + + #[test] + fn cars_schema_is_stable() { + let schema = ExampleDataset::Cars.schema(); + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| (f.name().as_str(), f.data_type().clone())) + .collect(); + + assert_eq!( + fields, + vec![ + ("car", DataType::Utf8), + ("speed", DataType::Float64), + ("time", DataType::Timestamp(TimeUnit::Nanosecond, None)), + ] + ); + } + + #[test] + fn regex_schema_is_stable() {
+ let schema = ExampleDataset::Regex.schema(); + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| (f.name().as_str(), f.data_type().clone())) + .collect(); + + assert_eq!( + fields, + vec![ + ("values", DataType::Utf8), + ("patterns", DataType::Utf8), + ("replacement", DataType::Utf8), + ("flags", DataType::Utf8), + ] + ); + } +} diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/types.rs b/datafusion-examples/src/utils/datasets/regex.rs similarity index 53% rename from datafusion/sqllogictest/src/engines/postgres_engine/types.rs rename to datafusion-examples/src/utils/datasets/regex.rs index 510462befb086..d44582126a053 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/types.rs +++ b/datafusion-examples/src/utils/datasets/regex.rs @@ -15,31 +15,16 @@ // specific language governing permissions and limitations // under the License. -use postgres_types::Type; -use std::fmt::Display; -use tokio_postgres::types::FromSql; +use std::sync::Arc; -pub struct PgRegtype { - value: String, -} - -impl<'a> FromSql<'a> for PgRegtype { - fn from_sql( - _: &Type, - buf: &'a [u8], - ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { - let oid = postgres_protocol::types::oid_from_sql(buf)?; - let value = Type::from_oid(oid).ok_or("bad type")?.to_string(); - Ok(PgRegtype { value }) - } - - fn accepts(ty: &Type) -> bool { - matches!(*ty, Type::REGTYPE) - } -} +use arrow::datatypes::{DataType, Field, Schema}; -impl Display for PgRegtype { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.value) - } +/// Schema for the `data/csv/regex.csv` example dataset. +pub fn schema() -> Arc<Schema> { + Arc::new(Schema::new(vec![ + Field::new("values", DataType::Utf8, false), + Field::new("patterns", DataType::Utf8, false), + Field::new("replacement", DataType::Utf8, false), + Field::new("flags", DataType::Utf8, true), + ])) } diff --git a/datafusion-examples/src/utils/example_metadata/discover.rs b/datafusion-examples/src/utils/example_metadata/discover.rs new file mode 100644 index 0000000000000..1ba5f6d29a14e --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/discover.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for discovering example groups in the repository filesystem. +//! +//! An example group is defined as a directory containing a `main.rs` file +//! under the examples root. This module is intentionally filesystem-focused +//! and does not perform any parsing or rendering. +//! Discovery fails if no valid example groups are found. + +use std::fs; +use std::path::{Path, PathBuf}; + +use datafusion::common::exec_err; +use datafusion::error::Result; + +/// Discovers all example group directories under the given root.
+/// +/// A directory is considered an example group if it contains a `main.rs` file. +pub fn discover_example_groups(root: &Path) -> Result<Vec<PathBuf>> { + let mut groups = Vec::new(); + for entry in fs::read_dir(root)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() && path.join("main.rs").is_file() { + groups.push(path); + } + } + + if groups.is_empty() { + return exec_err!("No example groups found under: {}", root.display()); + } + + groups.sort(); + Ok(groups) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs::{self, File}; + + use tempfile::TempDir; + + #[test] + fn discover_example_groups_finds_dirs_with_main_rs() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + + // valid example group + let group1 = root.join("group1"); + fs::create_dir(&group1)?; + File::create(group1.join("main.rs"))?; + + // not an example group + let group2 = root.join("group2"); + fs::create_dir(&group2)?; + + let groups = discover_example_groups(root)?; + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], group1); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_main_rs_is_a_directory() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + let group = root.join("group"); + fs::create_dir(&group)?; + fs::create_dir(group.join("main.rs"))?; + + let err = discover_example_groups(root).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_none_found() -> Result<()> { + let tmp = TempDir::new()?; + let err = discover_example_groups(tmp.path()).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/layout.rs b/datafusion-examples/src/utils/example_metadata/layout.rs new file mode 100644 index 0000000000000..ee6fad89855f9 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/layout.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Repository layout utilities. +//! +//! This module provides a small helper (`RepoLayout`) that encapsulates +//! knowledge about the DataFusion repository structure, in particular +//! where example groups are located relative to the repository root. + +use std::path::{Path, PathBuf}; + +use datafusion::error::{DataFusionError, Result}; + +/// Describes the layout of a DataFusion repository. +/// +/// This type centralizes knowledge about where example-related +/// directories live relative to the repository root.
+#[derive(Debug, Clone)] +pub struct RepoLayout { + root: PathBuf, +} + +impl From<&Path> for RepoLayout { + fn from(path: &Path) -> Self { + Self { + root: path.to_path_buf(), + } + } +} + +impl RepoLayout { + /// Creates a layout from an explicit repository root. + pub fn from_root(root: PathBuf) -> Self { + Self { root } + } + + /// Detects the repository root based on `CARGO_MANIFEST_DIR`. + /// + /// This is intended for use from binaries inside the workspace. + pub fn detect() -> Result<Self> { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + let root = manifest_dir.parent().ok_or_else(|| { + DataFusionError::Execution( + "CARGO_MANIFEST_DIR does not have a parent".to_string(), + ) + })?; + + Ok(Self { + root: root.to_path_buf(), + }) + } + + /// Returns the repository root directory. + pub fn root(&self) -> &Path { + &self.root + } + + /// Returns the `datafusion-examples/examples` directory. + pub fn examples_root(&self) -> PathBuf { + self.root.join("datafusion-examples").join("examples") + } + + /// Returns the directory for a single example group. + /// + /// Example: `examples/udf` + pub fn example_group_dir(&self, group: &str) -> PathBuf { + self.examples_root().join(group) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detect_sets_non_empty_root() -> Result<()> { + let layout = RepoLayout::detect()?; + assert!(!layout.root().as_os_str().is_empty()); + Ok(()) + } + + #[test] + fn examples_root_is_under_repo_root() -> Result<()> { + let layout = RepoLayout::detect()?; + let examples_root = layout.examples_root(); + assert!(examples_root.starts_with(layout.root())); + assert!(examples_root.ends_with("datafusion-examples/examples")); + Ok(()) + } + + #[test] + fn example_group_dir_appends_group_name() -> Result<()> { + let layout = RepoLayout::detect()?; + let group_dir = layout.example_group_dir("foo"); + assert!(group_dir.ends_with("datafusion-examples/examples/foo")); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/mod.rs b/datafusion-examples/src/utils/example_metadata/mod.rs new file mode 100644 index 0000000000000..ab4c8e4a8e4c2 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/mod.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Documentation generator for DataFusion examples. +//! +//! # Design goals +//! +//! - Keep README.md in sync with runnable examples +//! - Fail fast on malformed documentation +//! +//! # Overview +//! +//! Each example group corresponds to a directory under +//! `datafusion-examples/examples/` containing a `main.rs` file. +//! Documentation is extracted from structured `//!` comments in that file. +//! +//! For each example group, the generator produces: +//! +//! ```text +//! ## Examples +//!
### Group: `<group_name>` +//! #### Category: Single Process | Distributed +//! +//! | Subcommand | File Path | Description | +//! ``` +//! +//! # Usage +//! +//! Generate documentation for a single group only: +//! +//! ```bash +//! cargo run --bin examples-docs -- dataframe +//! ``` +//! +//! Generate documentation for all examples: +//! +//! ```bash +//! cargo run --bin examples-docs +//! ``` + +pub mod discover; +pub mod layout; +pub mod model; +pub mod parser; +pub mod render; + +#[cfg(test)] +pub mod test_utils; + +pub use layout::RepoLayout; +pub use model::{Category, ExampleEntry, ExampleGroup, GroupName}; +pub use parser::parse_main_rs_docs; +pub use render::generate_examples_readme; diff --git a/datafusion-examples/src/utils/example_metadata/model.rs b/datafusion-examples/src/utils/example_metadata/model.rs new file mode 100644 index 0000000000000..11416d141eb74 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/model.rs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Domain model for DataFusion example documentation. +//! +//! This module defines the core data structures used to represent +//! example groups, individual examples, and their categorization +//! as parsed from `main.rs` documentation comments. + +use std::path::Path; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::parse_main_rs_docs; + +/// Well-known abbreviations used to preserve correct capitalization +/// when generating human-readable documentation titles. +const ABBREVIATIONS: &[(&str, &str)] = &[ + ("dataframe", "DataFrame"), + ("io", "IO"), + ("sql", "SQL"), + ("udf", "UDF"), +]; + +/// A group of related examples (e.g. `builtin_functions`, `udf`). +/// +/// Each group corresponds to a directory containing a `main.rs` file +/// with structured documentation comments. +#[derive(Debug)] +pub struct ExampleGroup { + pub name: GroupName, + pub examples: Vec<ExampleEntry>, + pub category: Category, +} + +impl ExampleGroup { + /// Parses an example group from its directory. + /// + /// The group name is derived from the directory name, and example + /// entries are extracted from `main.rs`. + pub fn from_dir(dir: &Path, category: Category) -> Result<Self> { + let raw_name = dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })? + .to_string(); + + let name = GroupName::from_dir_name(raw_name); + let main_rs = dir.join("main.rs"); + let examples = parse_main_rs_docs(&main_rs)?; + + Ok(Self { + name, + examples, + category, + }) + } +} + +/// Represents an example group name in both raw and human-readable forms.
+/// +/// For example: +/// - raw: `builtin_functions` +/// - title: `Builtin Functions` +#[derive(Debug)] +pub struct GroupName { + raw: String, + title: String, +} + +impl GroupName { + /// Creates a group name from a directory name. + pub fn from_dir_name(raw: String) -> Self { + let title = raw + .split('_') + .map(format_part) + .collect::<Vec<_>>() + .join(" "); + + Self { raw, title } + } + + /// Returns the raw group name (directory name). + pub fn raw(&self) -> &str { + &self.raw + } + + /// Returns a title-cased name for documentation. + pub fn title(&self) -> &str { + &self.title + } +} + +/// A single runnable example within a group. +/// +/// Each entry corresponds to a subcommand documented in `main.rs`. +#[derive(Debug)] +pub struct ExampleEntry { + /// CLI subcommand name. + pub subcommand: String, + /// Rust source file name. + pub file: String, + /// Human-readable description. + pub desc: String, +} + +/// Execution category of an example group. +#[derive(Debug, Default)] +pub enum Category { + /// Runs in a single process. + #[default] + SingleProcess, + /// Requires a distributed setup. + Distributed, +} + +impl Category { + /// Returns the display name used in documentation. + pub fn name(&self) -> &str { + match self { + Self::SingleProcess => "Single Process", + Self::Distributed => "Distributed", + } + } + + /// Determines the category for a group by name. + pub fn for_group(name: &str) -> Self { + match name { + "flight" => Category::Distributed, + _ => Category::SingleProcess, + } + } +} + +/// Formats a single group-name segment for display. +/// +/// This function applies DataFusion-specific capitalization rules: +/// - Known abbreviations (e.g. `sql`, `io`, `udf`) are rendered in all caps +/// - All other segments fall back to standard Title Case +fn format_part(part: &str) -> String { + let lower = part.to_ascii_lowercase(); + + if let Some((_, replacement)) = ABBREVIATIONS.iter().find(|(k, _)| *k == lower) { + return replacement.to_string(); + } + + let mut chars = part.chars(); + match chars.next() { + Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(), + None => String::new(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::{ + assert_exec_err_contains, example_group_from_docs, + }; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn category_for_group_works() { + assert!(matches!( + Category::for_group("flight"), + Category::Distributed + )); + assert!(matches!( + Category::for_group("anything_else"), + Category::SingleProcess + )); + } + + #[test] + fn all_subcommand_is_ignored() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `all` — run all examples included in this module + //! + //! - `foo` + //! (file: foo.rs, desc: foo example) + "#, + )?; + assert_eq!(group.examples.len(), 1); + assert_eq!(group.examples[0].subcommand, "foo"); + Ok(()) + } + + #[test] + fn metadata_without_subcommand_fails() { + let err = example_group_from_docs("//!
(file: foo.rs, desc: missing subcommand)") + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn group_name_handles_abbreviations() { + assert_eq!( + GroupName::from_dir_name("dataframe".to_string()).title(), + "DataFrame" + ); + assert_eq!( + GroupName::from_dir_name("data_io".to_string()).title(), + "Data IO" + ); + assert_eq!( + GroupName::from_dir_name("sql_ops".to_string()).title(), + "SQL Ops" + ); + assert_eq!(GroupName::from_dir_name("udf".to_string()).title(), "UDF"); + } + + #[test] + fn group_name_title_cases() { + let cases = [ + ("very_long_group_name", "Very Long Group Name"), + ("foo", "Foo"), + ("dataframe", "DataFrame"), + ("data_io", "Data IO"), + ("sql_ops", "SQL Ops"), + ("udf", "UDF"), + ]; + for (input, expected) in cases { + let name = GroupName::from_dir_name(input.to_string()); + assert_eq!(name.title(), expected); + } + } + + #[test] + fn parse_group_example_works() -> Result<()> { + let tmp = TempDir::new().unwrap(); + + // Simulate: examples/builtin_functions/ + let group_dir = tmp.path().join("builtin_functions"); + fs::create_dir(&group_dir)?; + + // Write a fake main.rs with docs + let main_rs = group_dir.join("main.rs"); + fs::write( + &main_rs, + r#" + // Licensed to the Apache Software Foundation (ASF) under one + // or more contributor license agreements. See the NOTICE file + // distributed with this work for additional information + // regarding copyright ownership. The ASF licenses this file + // to you under the Apache License, Version 2.0 (the + // "License"); you may not use this file except in compliance + // with the License. You may obtain a copy of the License at + // + // http://www.apache.org/licenses/LICENSE-2.0 + // + // Unless required by applicable law or agreed to in writing, + // software distributed under the License is distributed on an + // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + // KIND, either express or implied. See the License for the + // specific language governing permissions and limitations + // under the License. + // + //! # These are miscellaneous function-related examples + //! + //! These examples demonstrate miscellaneous function-related features. + //! + //! ## Usage + //! ```bash + //! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] + //! ``` + //! + //! Each subcommand runs a corresponding example: + //! - `all` — run all examples included in this module + //! + //! - `date_time` + //! (file: date_time.rs, desc: Examples of date-time related functions and queries) + //! + //! - `function_factory` + //! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) + //! + //! - `regexp` + //! 
(file: regexp.rs, desc: Examples of using regular expression functions) + "#, + )?; + + let group = ExampleGroup::from_dir(&group_dir, Category::SingleProcess)?; + + // Assert group-level data + assert_eq!(group.name.title(), "Builtin Functions"); + assert_eq!(group.examples.len(), 3); + + // Assert the first example + assert_eq!(group.examples[0].subcommand, "date_time"); + assert_eq!(group.examples[0].file, "date_time.rs"); + assert_eq!( + group.examples[0].desc, + "Examples of date-time related functions and queries" + ); + + // Assert the second example + assert_eq!(group.examples[1].subcommand, "function_factory"); + assert_eq!(group.examples[1].file, "function_factory.rs"); + assert_eq!( + group.examples[1].desc, + "Register `CREATE FUNCTION` handler to implement SQL macros" + ); + + // Assert the third example + assert_eq!(group.examples[2].subcommand, "regexp"); + assert_eq!(group.examples[2].file, "regexp.rs"); + assert_eq!( + group.examples[2].desc, + "Examples of using regular expression functions" + ); + + Ok(()) + } + + #[test] + fn duplicate_metadata_without_repeating_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn duplicate_metadata_for_same_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! + //! - `foo` + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Duplicate metadata for subcommand `foo`"); + } + + #[test] + fn metadata_must_follow_subcommand() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! some unrelated comment + //! (file: foo.rs, desc: test) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn preserves_example_order_from_main_rs() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `second` + //! (file: second.rs, desc: second example) + //! + //! - `first` + //! (file: first.rs, desc: first example) + //! + //! - `third` + //! (file: third.rs, desc: third example) + "#, + )?; + + let subcommands: Vec<&str> = group + .examples + .iter() + .map(|e| e.subcommand.as_str()) + .collect(); + + assert_eq!( + subcommands, + vec!["second", "first", "third"], + "examples must preserve the order defined in main.rs" + ); + + Ok(()) + } + + #[test] + fn metadata_can_follow_blank_doc_line() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `foo` + //! + //! (file: foo.rs, desc: test) + "#, + )?; + assert_eq!(group.examples.len(), 1); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/parser.rs b/datafusion-examples/src/utils/example_metadata/parser.rs new file mode 100644 index 0000000000000..4ead3e5a2ae9f --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/parser.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parser for example metadata embedded in `main.rs` documentation comments. +//! +//! This module scans `//!` doc comments to extract example subcommands +//! and their associated metadata (file name and description), enforcing +//! a strict ordering and structure to avoid ambiguous documentation. + +use std::{collections::HashSet, fs, path::Path}; + +use datafusion::common::exec_err; +use datafusion::error::Result; +use nom::{ + Err, IResult, Parser, + bytes::complete::{tag, take_until, take_while}, + character::complete::multispace0, + combinator::all_consuming, + error::{Error, ErrorKind}, + sequence::{delimited, preceded}, +}; + +use crate::utils::example_metadata::ExampleEntry; + +/// Parsing state machine used while scanning `main.rs` docs. +/// +/// This makes the "subcommand - metadata" relationship explicit: +/// metadata is only valid immediately after a subcommand has been seen. +enum ParserState<'a> { + /// Not currently expecting metadata. + Idle, + /// A subcommand was just parsed; the next valid metadata (if any) + /// must belong to this subcommand. + SeenSubcommand(&'a str), +} + +/// Parses a subcommand declaration line from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! - `` +/// ``` +fn parse_subcommand_line(input: &str) -> IResult<&str, &str> { + let parser = preceded( + multispace0, + delimited(tag("//! - `"), take_until("`"), tag("`")), + ); + all_consuming(parser).parse(input) +} + +/// Parses example metadata (file name and description) from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! (file: .rs, desc: ) +/// ``` +fn parse_metadata_line(input: &str) -> IResult<&str, (&str, &str)> { + let parser = preceded( + multispace0, + preceded(tag("//!"), preceded(multispace0, take_while(|_| true))), + ); + let (rest, payload) = all_consuming(parser).parse(input)?; + + let content = payload + .strip_prefix("(") + .and_then(|s| s.strip_suffix(")")) + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + let (file, desc) = content + .strip_prefix("file:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))? + .split_once(", desc:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + Ok((rest, (file.trim(), desc.trim()))) +} + +/// Parses example entries from a group's `main.rs` file. 
+pub fn parse_main_rs_docs(path: &Path) -> Result<Vec<ExampleEntry>> { + let content = fs::read_to_string(path)?; + let mut entries = vec![]; + let mut state = ParserState::Idle; + let mut seen_subcommands = HashSet::new(); + + for (line_no, raw_line) in content.lines().enumerate() { + let line = raw_line.trim(); + + // Try parsing subcommand, excluding `all` because it's not used in README + if let Ok((_, sub)) = parse_subcommand_line(line) { + state = if sub == "all" { + ParserState::Idle + } else { + ParserState::SeenSubcommand(sub) + }; + continue; + } + + // Try parsing metadata + if let Ok((_, (file, desc))) = parse_metadata_line(line) { + let subcommand = match state { + ParserState::SeenSubcommand(s) => s, + ParserState::Idle => { + return exec_err!( + "Metadata without preceding subcommand at {}:{}", + path.display(), + line_no + 1 + ); + } + }; + + if !seen_subcommands.insert(subcommand) { + return exec_err!("Duplicate metadata for subcommand `{subcommand}`"); + } + + entries.push(ExampleEntry { + subcommand: subcommand.to_string(), + file: file.to_string(), + desc: desc.to_string(), + }); + + state = ParserState::Idle; + continue; + } + + // If a non-blank doc line interrupts a pending subcommand, reset the state + if let ParserState::SeenSubcommand(_) = state + && is_non_blank_doc_line(line) + { + state = ParserState::Idle; + } + } + + Ok(entries) +} + +/// Returns `true` for non-blank Rust doc comment lines (`//!`). +/// +/// Used to detect when a subcommand is interrupted by unrelated documentation, +/// so metadata is only accepted immediately after a subcommand (blank doc lines +/// are allowed in between). +fn is_non_blank_doc_line(line: &str) -> bool { + line.starts_with("//!") && !line.trim_start_matches("//!").trim().is_empty() +} + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::TempDir; + + #[test] + fn parse_subcommand_line_accepts_valid_input() { + let line = "//! - `date_time`"; + let sub = parse_subcommand_line(line); + assert_eq!(sub, Ok(("", "date_time"))); + } + + #[test] + fn parse_subcommand_line_invalid_inputs() { + let err_lines = [ + "//! - ", + "//! - foo", + "//! - `foo` bar", + "//! --", + "//!-", + "//!--", + "//!", + "//", + "/", + "", + ]; + for line in err_lines { + assert!( + parse_subcommand_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_metadata_line_accepts_valid_input() { + let line = + "//! (file: date_time.rs, desc: Examples of date-time related functions)"; + let res = parse_metadata_line(line); + assert_eq!( + res, + Ok(( + "", + ("date_time.rs", "Examples of date-time related functions") + )) + ); + + let line = "//! (file: foo.rs, desc: Foo, bar, baz)"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo, bar, baz")))); + + let line = "//! (file: foo.rs, desc: Foo(FOO))"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo(FOO)")))); + } + + #[test] + fn parse_metadata_line_invalid_inputs() { + let bad_lines = [ + "//! (file: foo.rs)", + "//! (desc: missing file)", + "//! file: foo.rs, desc: test", + "//! file: foo.rs,desc: test", + "//! (file: foo.rs desc: test)", + "//! (file: foo.rs,desc: test)", + "//! (desc: test, file: foo.rs)", + "//! ()", + "//!
(file: foo.rs, desc: test) extra", + "", + ]; + for line in bad_lines { + assert!( + parse_metadata_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_main_rs_docs_extracts_entries() -> Result<()> { + let tmp = TempDir::new().unwrap(); + let main_rs = tmp.path().join("main.rs"); + + fs::write( + &main_rs, + r#" + //! - `foo` + //! (file: foo.rs, desc: first example) + //! + //! - `bar` + //! (file: bar.rs, desc: second example) + "#, + )?; + + let entries = parse_main_rs_docs(&main_rs)?; + + assert_eq!(entries.len(), 2); + + assert_eq!(entries[0].subcommand, "foo"); + assert_eq!(entries[0].file, "foo.rs"); + assert_eq!(entries[0].desc, "first example"); + + assert_eq!(entries[1].subcommand, "bar"); + assert_eq!(entries[1].file, "bar.rs"); + assert_eq!(entries[1].desc, "second example"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/render.rs b/datafusion-examples/src/utils/example_metadata/render.rs new file mode 100644 index 0000000000000..a4ea620e78352 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/render.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Markdown renderer for DataFusion example documentation. +//! +//! This module takes parsed example metadata and generates the +//! `README.md` content for `datafusion-examples`, including group +//! sections and example tables. + +use std::path::PathBuf; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::discover::discover_example_groups; +use crate::utils::example_metadata::model::ExampleGroup; +use crate::utils::example_metadata::{Category, RepoLayout}; + +const STATIC_HEADER: &str = r#" + +# DataFusion Examples + +This crate includes end to end, highly commented examples of how to use +various DataFusion APIs to help you get started. + +## Prerequisites + +Run `git submodule update --init` to init test files. + +## Running Examples + +To run an example, use the `cargo run` command, such as: + +```bash +git clone https://github.com/apache/datafusion +cd datafusion +# Download test data +git submodule update --init + +# Change to the examples directory +cd datafusion-examples/examples + +# Run all examples in a group +cargo run --example -- all + +# Run a specific example within a group +cargo run --example -- + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe +``` +"#; + +/// Generates Markdown documentation for DataFusion examples. +/// +/// If `group` is `None`, documentation is generated for all example groups. 
+/// If `group` is `Some`, only that group is rendered. +/// +/// # Errors +/// +/// Returns an error if: +/// - the requested group does not exist +/// - a `main.rs` file is missing +/// - documentation comments are malformed +pub fn generate_examples_readme( + layout: &RepoLayout, + group: Option<&str>, +) -> Result<String> { + let examples_root = layout.examples_root(); + + let mut out = String::new(); + out.push_str(STATIC_HEADER); + + let group_dirs: Vec<PathBuf> = match group { + Some(name) => { + let dir = examples_root.join(name); + if !dir.is_dir() { + return Err(DataFusionError::Execution(format!( + "Example group `{name}` does not exist" + ))); + } + vec![dir] + } + None => discover_example_groups(&examples_root)?, + }; + + for group_dir in group_dirs { + let raw_name = + group_dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })?; + + let category = Category::for_group(raw_name); + let group = ExampleGroup::from_dir(&group_dir, category)?; + + out.push_str(&group.render_markdown()); + } + + Ok(out) +} + +impl ExampleGroup { + /// Renders this example group as a Markdown section for the README. + pub fn render_markdown(&self) -> String { + let mut out = String::new(); + out.push_str(&format!("\n## {} Examples\n\n", self.name.title())); + out.push_str(&format!("### Group: `{}`\n\n", self.name.raw())); + out.push_str(&format!("#### Category: {}\n\n", self.category.name())); + out.push_str("| Subcommand | File Path | Description |\n"); + out.push_str("| --- | --- | --- |\n"); + + for example in &self.examples { + out.push_str(&format!( + "| {} | [`{}/{}`](examples/{}/{}) | {} |\n", + example.subcommand, + self.name.raw(), + example.file, + self.name.raw(), + example.file, + example.desc + )); + } + + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn single_group_generation_works() { + let tmp = TempDir::new().unwrap(); + // Fake repo root + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + + // Create: datafusion-examples/examples/builtin_functions + let examples_dir = layout.example_group_dir("builtin_functions"); + fs::create_dir_all(&examples_dir).unwrap(); + + fs::write( + examples_dir.join("main.rs"), + "//! - `x`\n//! (file: foo.rs, desc: test)", + ) + .unwrap(); + + let out = generate_examples_readme(&layout, Some("builtin_functions")).unwrap(); + assert!(out.contains("Builtin Functions")); + assert!(out.contains("| x | [`builtin_functions/foo.rs`]")); + } + + #[test] + fn single_group_generation_fails_if_group_missing() { + let tmp = TempDir::new().unwrap(); + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + let err = generate_examples_readme(&layout, Some("missing_group")).unwrap_err(); + assert_exec_err_contains(err, "Example group `missing_group` does not exist"); + } +} diff --git a/datafusion-examples/src/utils/example_metadata/test_utils.rs b/datafusion-examples/src/utils/example_metadata/test_utils.rs new file mode 100644 index 0000000000000..d6ab3b06ba06d --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/test_utils.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test helpers for example metadata parsing and validation. +//! +//! This module provides small, focused utilities to reduce duplication +//! and keep tests readable across the example metadata submodules. + +use std::fs; + +use datafusion::error::{DataFusionError, Result}; +use tempfile::TempDir; + +use crate::utils::example_metadata::{Category, ExampleGroup}; + +/// Asserts that an `Execution` error contains the expected message fragment. +/// +/// Keeps tests focused on semantic error causes without coupling them +/// to full error string formatting. +pub fn assert_exec_err_contains(err: DataFusionError, needle: &str) { + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains(needle), + "expected '{needle}' in error message, got: {msg}" + ); + } + other => panic!("expected Execution error, got: {other:?}"), + } +} + +/// Helper for grammar-focused tests. +/// +/// Creates a minimal temporary example group with a single `main.rs` +/// containing the provided docs. Intended for testing parsing and +/// validation rules, not full integration behavior. +pub fn example_group_from_docs(docs: &str) -> Result<ExampleGroup> { + let tmp = TempDir::new().map_err(|e| { + DataFusionError::Execution(format!("Failed initializing temp dir: {e}")) + })?; + let dir = tmp.path().join("group"); + fs::create_dir(&dir).map_err(|e| { + DataFusionError::Execution(format!("Failed creating temp dir: {e}")) + })?; + fs::write(dir.join("main.rs"), docs).map_err(|e| { + DataFusionError::Execution(format!("Failed writing to temp file: {e}")) + })?; + ExampleGroup::from_dir(&dir, Category::SingleProcess) +} diff --git a/datafusion-examples/src/utils/mod.rs b/datafusion-examples/src/utils/mod.rs new file mode 100644 index 0000000000000..da96724a49cb3 --- /dev/null +++ b/datafusion-examples/src/utils/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
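The pieces above (layout detection, parsing, rendering, test helpers) are meant to be driven by the `examples-docs` binary mentioned in the module docs. That binary is not part of this hunk; the following is only a minimal sketch of how such a driver could wire the public API together, assuming the `datafusion_examples::utils::example_metadata` re-exports shown above:

```rust
// Hypothetical driver sketch -- the real examples-docs binary is not shown
// in this diff. Only the re-exported API above is assumed.
use datafusion::error::Result;
use datafusion_examples::utils::example_metadata::{
    generate_examples_readme, RepoLayout,
};

fn main() -> Result<()> {
    // Optional group filter, e.g. `examples-docs dataframe`
    let group = std::env::args().nth(1);
    // Repository root derived from CARGO_MANIFEST_DIR
    let layout = RepoLayout::detect()?;
    // Render one group, or the full README body when no group is given
    let markdown = generate_examples_readme(&layout, group.as_deref())?;
    print!("{markdown}");
    Ok(())
}
```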
+ +mod csv_to_parquet; +pub mod datasets; +pub mod example_metadata; + +pub use csv_to_parquet::write_csv_to_parquet; diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index 4eaeed675a206..61b55397137df 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -46,11 +46,14 @@ futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -tokio = { workspace = true } [dev-dependencies] +chrono = { workspace = true } datafusion-datasource-parquet = { workspace = true } +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true diff --git a/datafusion/catalog-listing/src/config.rs b/datafusion/catalog-listing/src/config.rs index 3370d2ea75535..ca4d2abfcd737 100644 --- a/datafusion/catalog-listing/src/config.rs +++ b/datafusion/catalog-listing/src/config.rs @@ -19,9 +19,10 @@ use crate::options::ListingOptions; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::{config_err, internal_err}; +use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_compression_type::FileCompressionType; +#[expect(deprecated)] use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::ListingTableUrl; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use std::str::FromStr; use std::sync::Arc; @@ -44,15 +45,12 @@ pub enum SchemaSource { /// # Schema Evolution Support /// /// This configuration supports schema evolution through the optional -/// [`SchemaAdapterFactory`]. You might want to override the default factory when you need: +/// [`PhysicalExprAdapterFactory`]. You might want to override the default factory when you need: /// /// - **Type coercion requirements**: When you need custom logic for converting between /// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) /// - **Column mapping**: You need to map columns with a legacy name to a new name /// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. -/// -/// If not specified, a [`datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory`] -/// will be used, which handles basic schema compatibility cases. #[derive(Debug, Clone, Default)] pub struct ListingTableConfig { /// Paths on the `ObjectStore` for creating [`crate::ListingTable`]. 
@@ -68,8 +66,6 @@ pub struct ListingTableConfig { pub options: Option<ListingOptions>, /// Tracks the source of the schema information pub(crate) schema_source: SchemaSource, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - pub(crate) schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>, /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters pub(crate) expr_adapter_factory: Option<Arc<dyn PhysicalExprAdapterFactory>>, } @@ -218,8 +214,7 @@ impl ListingTableConfig { file_schema, options: _, schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, + expr_adapter_factory, } = self; let (schema, new_schema_source) = match file_schema { @@ -241,8 +236,7 @@ impl ListingTableConfig { file_schema: Some(schema), options: Some(options), schema_source: new_schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, + expr_adapter_factory, }) } None => internal_err!("No `ListingOptions` set for inferring schema"), @@ -282,7 +276,6 @@ impl ListingTableConfig { file_schema: self.file_schema, options: Some(options), schema_source: self.schema_source, - schema_adapter_factory: self.schema_adapter_factory, expr_adapter_factory: self.expr_adapter_factory, }) } @@ -290,63 +283,11 @@ } } - /// Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] - /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. - /// - /// If not provided, a default schema adapter factory will be used. - /// - /// # Example: Custom Schema Adapter for Type Coercion - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; - /// # use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion_datasource::ListingTableUrl; - /// # use datafusion_datasource_parquet::file_format::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # - /// # #[derive(Debug)] - /// # struct MySchemaAdapterFactory; - /// # impl SchemaAdapterFactory for MySchemaAdapterFactory { - /// # fn create(&self, _projected_table_schema: SchemaRef, _file_schema: SchemaRef) -> Box<dyn SchemaAdapter> { - /// # unimplemented!() - /// # } - /// # } - /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let table_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// let config = ListingTableConfig::new(table_paths) - /// .with_listing_options(listing_options) - /// .with_schema(table_schema) - /// .with_schema_adapter_factory(Arc::new(MySchemaAdapterFactory)); - /// ``` - pub fn with_schema_adapter_factory( - self, - schema_adapter_factory: Arc<dyn SchemaAdapterFactory>, - ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this configuration - pub fn schema_adapter_factory(&self) -> Option<&Arc<dyn SchemaAdapterFactory>> { - self.schema_adapter_factory.as_ref() - } - /// Set the [`PhysicalExprAdapterFactory`] for the [`crate::ListingTable`] /// /// The expression adapter factory is used to create physical expression adapters that can /// handle schema evolution and type conversions when evaluating expressions /// with different schemas than the table schema.
- /// - /// If not provided, a default physical expression adapter factory will be used unless a custom - /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used. - /// - /// See for details on this transition. pub fn with_expr_adapter_factory( self, expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>, @@ -356,4 +297,23 @@ ..self } } + + /// Deprecated: Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] + /// + /// `SchemaAdapterFactory` has been removed. Use [`Self::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. + /// + /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn with_schema_adapter_factory( + self, + _schema_adapter_factory: Arc<dyn SchemaAdapterFactory>, + ) -> Self { + // No-op - just return self unchanged + self + } } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 82cc36867939e..c6305c30008ce 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -21,25 +21,23 @@ use std::mem; use std::sync::Arc; use datafusion_catalog::Session; -use datafusion_common::internal_err; -use datafusion_common::{HashMap, Result, ScalarValue}; +use datafusion_common::{HashMap, Result, ScalarValue, assert_or_internal_err}; use datafusion_datasource::ListingTableUrl; use datafusion_datasource::PartitionedFile; -use datafusion_expr::{BinaryExpr, Operator}; +use datafusion_expr::{BinaryExpr, Operator, lit, utils}; use arrow::{ - array::{Array, ArrayRef, AsArray, StringBuilder}, - compute::{and, cast, prep_null_mask_filter}, - datatypes::{DataType, Field, Fields, Schema}, + array::AsArray, + datatypes::{DataType, Field}, record_batch::RecordBatch, }; use datafusion_expr::execution_props::ExecutionProps; use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt, stream::BoxStream}; use log::{debug, trace}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, DFSchema, DataFusionError}; +use datafusion_common::{Column, DFSchema}; use datafusion_expr::{Expr, Volatility}; use datafusion_physical_expr::create_physical_expr; use object_store::path::Path; @@ -53,7 +51,7 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. }) => { + Expr::Column(Column { name, ..
}) => { is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) @@ -85,13 +83,28 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context + // https://github.com/apache/datafusion/issues/21690 + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } + Expr::HigherOrderFunction(hof) => { + match hof.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + // https://github.com/apache/datafusion/issues/21690 Volatility::Stable | Volatility::Volatile => { is_applicable = false; Ok(TreeNodeRecursion::Stop) @@ -103,6 +116,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context + // https://github.com/apache/datafusion/issues/21690 // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction { .. 
} @@ -239,105 +253,6 @@ pub async fn list_partitions( Ok(out) } -async fn prune_partitions( - table_path: &ListingTableUrl, - partitions: Vec<Partition>, - filters: &[Expr], - partition_cols: &[(String, DataType)], -) -> Result<Vec<Partition>> { - if filters.is_empty() { - // prune partitions which don't contain the partition columns - return Ok(partitions - .into_iter() - .filter(|p| { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - !parse_partitions_for_path(table_path, &p.path, cols) - .unwrap_or_default() - .is_empty() - }) - .collect()); - } - - let mut builders: Vec<_> = (0..partition_cols.len()) - .map(|_| StringBuilder::with_capacity(partitions.len(), partitions.len() * 10)) - .collect(); - - for partition in &partitions { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - let parsed = parse_partitions_for_path(table_path, &partition.path, cols) - .unwrap_or_default(); - - let mut builders = builders.iter_mut(); - for (p, b) in parsed.iter().zip(&mut builders) { - b.append_value(p); - } - builders.for_each(|b| b.append_null()); - } - - let arrays = partition_cols - .iter() - .zip(builders) - .map(|((_, d), mut builder)| { - let array = builder.finish(); - cast(&array, d) - }) - .collect::<Result<_, _>>()?; - - let fields: Fields = partition_cols - .iter() - .map(|(n, d)| Field::new(n, d.clone(), true)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - let df_schema = DFSchema::from_unqualified_fields( - partition_cols - .iter() - .map(|(n, d)| Field::new(n, d.clone(), true)) - .collect(), - Default::default(), - )?; - - let batch = RecordBatch::try_new(schema, arrays)?; - - // TODO: Plumb this down - let props = ExecutionProps::new(); - - // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Result<ArrayRef> { - let expr = create_physical_expr(filter, &df_schema, &props)?; - expr.evaluate(&batch)?.into_array(partitions.len()) - }; - - //.Compute the conjunction of the filters - let mask = filters - .iter() - .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) - .reduce(|a, b| Ok(and(&a?, &b?)?)); - - let mask = match mask { - Some(Ok(mask)) => mask, - Some(Err(err)) => return Err(err), - None => return Ok(partitions), - }; - - // Don't retain partitions that evaluated to null - let prepared = match mask.null_count() { - 0 => mask, - _ => prep_null_mask_filter(&mask), - }; - - // Sanity check - assert_eq!(prepared.len(), partitions.len()); - - let filtered = partitions - .into_iter() - .zip(prepared.values()) - .filter_map(|(p, f)| f.then_some(p)) - .collect(); - - Ok(filtered) -} - #[derive(Debug)] enum PartitionValue { Single(String), @@ -348,16 +263,11 @@ fn populate_partition_values<'a>( partition_values: &mut HashMap<&'a str, PartitionValue>, filter: &'a Expr, ) { - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = filter - { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = filter { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) - | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { name, ..
})) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -412,6 +322,70 @@ pub fn evaluate_partition_prefix<'a>( } } +fn filter_partitions( + pf: PartitionedFile, + filters: &[Expr], + df_schema: &DFSchema, +) -> Result<Option<PartitionedFile>> { + if pf.partition_values.is_empty() && !filters.is_empty() { + return Ok(None); + } else if filters.is_empty() { + return Ok(Some(pf)); + } + + let arrays = pf + .partition_values + .iter() + .map(|v| v.to_array()) + .collect::<Result<_, _>>()?; + + let batch = RecordBatch::try_new(Arc::clone(df_schema.inner()), arrays)?; + + let filter = utils::conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)); + let props = ExecutionProps::new(); + let expr = create_physical_expr(&filter, df_schema, &props)?; + + // Since we're only operating on a single file, our batch and resulting "array" holds only one + // value indicating if the input file matches the provided filters + let matches = expr.evaluate(&batch)?.into_array(1)?; + if matches.as_boolean().value(0) { + return Ok(Some(pf)); + } + + Ok(None) +} + +/// Returns `Ok(None)` when the file is not inside a valid partition path +/// (e.g. a stale file in the table root directory). Such files are skipped +/// because hive-style partition values are never null and there is no valid +/// value to assign for non-partitioned files. +fn try_into_partitioned_file( + object_meta: ObjectMeta, + partition_cols: &[(String, DataType)], + table_path: &ListingTableUrl, +) -> Result<Option<PartitionedFile>> { + let cols = partition_cols.iter().map(|(name, _)| name.as_str()); + let parsed = parse_partitions_for_path(table_path, &object_meta.location, cols); + + let Some(parsed) = parsed else { + // parse_partitions_for_path already logs a debug message + return Ok(None); + }; + + let partition_values = parsed + .into_iter() + .zip(partition_cols) + .map(|(parsed, (_, datatype))| { + ScalarValue::try_from_string(parsed.to_string(), datatype) + }) + .collect::<Result<Vec<_>>>()?; + + let mut pf: PartitionedFile = object_meta.into(); + pf.partition_values = partition_values; + + Ok(Some(pf)) +} + /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. /// `filters` should only contain expressions that can be evaluated @@ -424,80 +398,48 @@ pub async fn pruned_partition_list<'a>( file_extension: &'a str, partition_cols: &'a [(String, DataType)], ) -> Result<BoxStream<'a, Result<PartitionedFile>>> { - // if no partition col => simply list all the files - if partition_cols.is_empty() { - if !filters.is_empty() { - return internal_err!( - "Got partition filters for unpartitioned table {}", - table_path - ); - } - return Ok(Box::pin( - table_path - .list_all_files(ctx, store, file_extension) - .await? - .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)) - .map_ok(|object_meta| object_meta.into()), - )); - } - - let partition_prefix = evaluate_partition_prefix(partition_cols, filters); - - let partitions = - list_partitions(store, table_path, partition_cols.len(), partition_prefix) - .await?; - debug!("Listed {} partitions", partitions.len()); + let prefix = if !partition_cols.is_empty() { + evaluate_partition_prefix(partition_cols, filters) + } else { + None + }; - let pruned = - prune_partitions(table_path, partitions, filters, partition_cols).await?; + let objects = table_path + .list_prefixed_files(ctx, store, prefix, file_extension) + .await?
+ .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)); - debug!("Pruning yielded {} partitions", pruned.len()); + if partition_cols.is_empty() { + assert_or_internal_err!( + filters.is_empty(), + "Got partition filters for unpartitioned table {}", + table_path + ); - let stream = futures::stream::iter(pruned) - .map(move |partition: Partition| async move { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - let parsed = parse_partitions_for_path(table_path, &partition.path, cols); + // if no partition col => simply list all the files + Ok(objects.map_ok(|object_meta| object_meta.into()).boxed()) + } else { + let df_schema = DFSchema::from_unqualified_fields( + partition_cols + .iter() + .map(|(n, d)| Field::new(n, d.clone(), true)) + .collect(), + Default::default(), + )?; - let partition_values = parsed - .into_iter() - .flatten() - .zip(partition_cols) - .map(|(parsed, (_, datatype))| { - ScalarValue::try_from_string(parsed.to_string(), datatype) - }) - .collect::<Result<Vec<_>>>()?; - - let files = match partition.files { - Some(files) => files, - None => { - trace!("Recursively listing partition {}", partition.path); - store.list(Some(&partition.path)).try_collect().await? - } - }; - let files = files.into_iter().filter(move |o| { - let extension_match = o.location.as_ref().ends_with(file_extension); - // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) - let glob_match = table_path.contains(&o.location, false); - extension_match && glob_match - }); - - let stream = futures::stream::iter(files.map(move |object_meta| { - Ok(PartitionedFile { + Ok(objects + .try_filter_map(|object_meta| { + futures::future::ready(try_into_partitioned_file( object_meta, - partition_values: partition_values.clone(), - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }) - })); - - Ok::<_, DataFusionError>(stream) - }) - .buffer_unordered(CONCURRENCY_LIMIT) - .try_flatten() - .boxed(); - Ok(stream) + partition_cols, + table_path, + )) + }) + .try_filter_map(move |pf| { + futures::future::ready(filter_partitions(pf, filters, &df_schema)) + }) + .boxed()) + } } /// Extract the partition values for the given `file_path` (in the given `table_path`) @@ -541,22 +483,11 @@ pub fn describe_partition(partition: &Partition) -> (&str, usize, Vec<&str>) { #[cfg(test)] mod tests { - use async_trait::async_trait; - use datafusion_common::config::TableOptions; use datafusion_datasource::file_groups::FileGroup; - use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::RuntimeEnv; - use futures::FutureExt; - use object_store::memory::InMemory; - use std::any::Any; use std::ops::Not; use super::*; - use datafusion_expr::{ - case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF, - }; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_plan::ExecutionPlan; + use datafusion_expr::{case, col}; #[test] fn test_split_files() { @@ -599,209 +530,6 @@ assert_eq!(0, chunks.len()); } - #[tokio::test] - async fn test_pruned_partition_list_empty() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/mypartition=val1/notparquetfile", 100), - ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), - ("tablepath/file.parquet", 100), - ("tablepath/notapartition/file.parquet", 100), - ("tablepath/notmypartition=val1/file.parquet", 100), - ]); - let filter = Expr::eq(col("mypartition"), lit("val1")); - let pruned = pruned_partition_list(
state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter], - ".parquet", - &[(String::from("mypartition"), DataType::Utf8)], - ) - .await - .expect("partition pruning failed") - .collect::<Vec<_>>() - .await; - - assert_eq!(pruned.len(), 0); - } - - #[tokio::test] - async fn test_pruned_partition_list() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/mypartition=val1/file.parquet", 100), - ("tablepath/mypartition=val2/file.parquet", 100), - ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), - ("tablepath/mypartition=val1/other=val3/file.parquet", 100), - ("tablepath/notapartition/file.parquet", 100), - ("tablepath/notmypartition=val1/file.parquet", 100), - ]); - let filter = Expr::eq(col("mypartition"), lit("val1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter], - ".parquet", - &[(String::from("mypartition"), DataType::Utf8)], - ) - .await - .expect("partition pruning failed") - .try_collect::<Vec<_>>() - .await - .unwrap(); - - assert_eq!(pruned.len(), 2); - let f1 = &pruned[0]; - assert_eq!( - f1.object_meta.location.as_ref(), - "tablepath/mypartition=val1/file.parquet" - ); - assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); - let f2 = &pruned[1]; - assert_eq!( - f2.object_meta.location.as_ref(), - "tablepath/mypartition=val1/other=val3/file.parquet" - ); - assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); - } - - #[tokio::test] - async fn test_pruned_partition_list_multi() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/part1=p1v1/file.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), - ]); - let filter1 = Expr::eq(col("part1"), lit("p1v2")); - let filter2 = Expr::eq(col("part2"), lit("p2v1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2], - ".parquet", - &[ - (String::from("part1"), DataType::Utf8), - (String::from("part2"), DataType::Utf8), - ], - ) - .await - .expect("partition pruning failed") - .try_collect::<Vec<_>>() - .await - .unwrap(); - - assert_eq!(pruned.len(), 2); - let f1 = &pruned[0]; - assert_eq!( - f1.object_meta.location.as_ref(), - "tablepath/part1=p1v2/part2=p2v1/file1.parquet" - ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] - ); - let f2 = &pruned[1]; - assert_eq!( - f2.object_meta.location.as_ref(), - "tablepath/part1=p1v2/part2=p2v1/file2.parquet" - ); - assert_eq!( - &f2.partition_values, - &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] - ); - } - - #[tokio::test] - async fn test_list_partition() { - let (store, _) = make_test_store_and_state(&[ - ("tablepath/part1=p1v1/file.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), - ]); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 0, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter()
.map(describe_partition) - .collect::<Vec<_>>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec![]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ] - ); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 1, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - .collect::<Vec<_>>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), - ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), - ] - ); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 2, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - .collect::<Vec<_>>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ( - "tablepath/part1=p1v2/part2=p2v1", - 2, - vec!["file1.parquet", "file2.parquet"] - ), - ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), - ("tablepath/part1=p1v3/part2=p2v1", 2, vec!["file3.parquet"]), - ] - ); - } - #[test] fn test_parse_partitions_for_path() { assert_eq!( @@ -871,6 +599,130 @@ ); } + #[test] + fn test_try_into_partitioned_file_valid_partition() { + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![("year_month".to_string(), DataType::Utf8)]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/year_month=2024-01/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!(result.is_some()); + let pf = result.unwrap(); + assert_eq!(pf.partition_values.len(), 1); + assert_eq!( + pf.partition_values[0], + ScalarValue::Utf8(Some("2024-01".to_string())) + ); + } + + #[test] + fn test_try_into_partitioned_file_root_file_skipped() { + // File in root directory (not inside any partition path) should be + // skipped — this is the case where a stale file exists from before + // hive partitioning was added.
+ let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![("year_month".to_string(), DataType::Utf8)]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!( + result.is_none(), + "Files outside partition structure should be skipped" + ); + } + + #[test] + fn test_try_into_partitioned_file_wrong_partition_name() { + // File in a directory that doesn't match the expected partition column + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![("year_month".to_string(), DataType::Utf8)]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/wrong_col=2024-01/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!( + result.is_none(), + "Files with wrong partition column name should be skipped" + ); + } + + #[test] + fn test_try_into_partitioned_file_multiple_partitions() { + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![ + ("year".to_string(), DataType::Utf8), + ("month".to_string(), DataType::Utf8), + ]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/year=2024/month=01/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!(result.is_some()); + let pf = result.unwrap(); + assert_eq!(pf.partition_values.len(), 2); + assert_eq!( + pf.partition_values[0], + ScalarValue::Utf8(Some("2024".to_string())) + ); + assert_eq!( + pf.partition_values[1], + ScalarValue::Utf8(Some("01".to_string())) + ); + } + + #[test] + fn test_try_into_partitioned_file_partial_partition_skipped() { + // File has first partition but not second — should be skipped + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![ + ("year".to_string(), DataType::Utf8), + ("month".to_string(), DataType::Utf8), + ]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/year=2024/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + // File has year=2024 but no month= directory — parse_partitions_for_path + // returns None because the path component "data.parquet" doesn't match + // the expected "month=..." pattern. 
+ assert!( + result.is_none(), + "Files with incomplete partition structure should be skipped" + ); + } + #[test] fn test_expr_applicable_for_cols() { assert!(expr_applicable_for_cols( @@ -1016,86 +868,4 @@ Some(Path::from("a=1970-01-05")), ); } - - pub fn make_test_store_and_state( - files: &[(&str, u64)], - ) -> (Arc<dyn ObjectStore>, Arc<dyn Session>) { - let memory = InMemory::new(); - - for (name, size) in files { - memory - .put(&Path::from(*name), vec![0; *size as usize].into()) - .now_or_never() - .unwrap() - .unwrap(); - } - - (Arc::new(memory), Arc::new(MockSession {})) - } - - struct MockSession {} - - #[async_trait] - impl Session for MockSession { - fn session_id(&self) -> &str { - unimplemented!() - } - - fn config(&self) -> &SessionConfig { - unimplemented!() - } - - async fn create_physical_plan( - &self, - _logical_plan: &LogicalPlan, - ) -> Result<Arc<dyn ExecutionPlan>> { - unimplemented!() - } - - fn create_physical_expr( - &self, - _expr: Expr, - _df_schema: &DFSchema, - ) -> Result<Arc<dyn PhysicalExpr>> { - unimplemented!() - } - - fn scalar_functions(&self) -> &std::collections::HashMap<String, Arc<ScalarUDF>> { - unimplemented!() - } - - fn aggregate_functions( - &self, - ) -> &std::collections::HashMap<String, Arc<AggregateUDF>> { - unimplemented!() - } - - fn window_functions(&self) -> &std::collections::HashMap<String, Arc<WindowUDF>> { - unimplemented!() - } - - fn runtime_env(&self) -> &Arc<RuntimeEnv> { - unimplemented!() - } - - fn execution_props(&self) -> &ExecutionProps { - unimplemented!() - } - - fn as_any(&self) -> &dyn Any { - unimplemented!() - } - - fn table_options(&self) -> &TableOptions { - unimplemented!() - } - - fn table_options_mut(&mut self) -> &mut TableOptions { - unimplemented!() - } - - fn task_ctx(&self) -> Arc<TaskContext> { - unimplemented!() - } - } } diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 90d04b46b8067..9efb5aa96267e 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License.
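To make the helpers.rs refactor above concrete: instead of building per-partition string arrays and filtering them in bulk, each listed object is now mapped to a `PartitionedFile` from its hive path and then kept or dropped by evaluating the pushed-down filters against a one-row batch. The following is a standalone sketch of that per-file evaluation, simplified from `filter_partitions`; the single `year` column and the hard-coded filter are assumptions for illustration only:

```rust
// Illustrative sketch of the per-file filter evaluation in the new pipeline
// (simplified from `filter_partitions` above; not part of the patch).
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{col, lit};
use datafusion_physical_expr::create_physical_expr;

/// Returns true when a file under `year=<partition_value>` survives the
/// (assumed) pushed-down filter `year = '2024'`.
fn file_matches(partition_value: &str) -> Result<bool> {
    // Schema holding only the hive partition column
    let df_schema = DFSchema::from_unqualified_fields(
        vec![Field::new("year", DataType::Utf8, true)].into(),
        Default::default(),
    )?;

    // One-row batch carrying this file's partition value
    let array = ScalarValue::Utf8(Some(partition_value.to_string())).to_array()?;
    let batch = RecordBatch::try_new(Arc::clone(df_schema.inner()), vec![array])?;

    // Compile the logical filter to a physical expression
    let filter = col("year").eq(lit("2024"));
    let expr = create_physical_expr(&filter, &df_schema, &ExecutionProps::new())?;

    // The resulting boolean array has exactly one value: keep or skip the file
    let matches = expr.evaluate(&batch)?.into_array(1)?;
    Ok(matches.as_boolean().value(0))
}
```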
+#![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -31,4 +32,4 @@ mod table; pub use config::{ListingTableConfig, SchemaSource}; pub use options::ListingOptions; -pub use table::ListingTable; +pub use table::{ListFilesResult, ListingTable}; diff --git a/datafusion/catalog-listing/src/options.rs b/datafusion/catalog-listing/src/options.rs index 7da8005f90ec2..146f98d62335e 100644 --- a/datafusion/catalog-listing/src/options.rs +++ b/datafusion/catalog-listing/src/options.rs @@ -18,12 +18,12 @@ use arrow::datatypes::{DataType, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::plan_err; -use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::ListingTableUrl; +use datafusion_datasource::file_format::FileFormat; use datafusion_execution::config::SessionConfig; use datafusion_expr::SortExpr; use futures::StreamExt; -use futures::{future, TryStreamExt}; +use futures::{TryStreamExt, future}; use itertools::Itertools; use std::sync::Arc; diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs index 95f9523d4401c..06ba8c8113fac 100644 --- a/datafusion/catalog-listing/src/table.rs +++ b/datafusion/catalog-listing/src/table.rs @@ -23,19 +23,18 @@ use async_trait::async_trait; use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_datafusion_err, plan_err, project_schema, Constraints, DataFusionError, - SchemaExt, Statistics, + Constraints, SchemaExt, Statistics, internal_datafusion_err, plan_err, project_schema, }; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::file_sink_config::FileSinkConfig; -use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, -}; +use datafusion_datasource::file_sink_config::{FileOutputMode, FileSinkConfig}; +#[expect(deprecated)] +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_datasource::{ - compute_all_files_statistics, ListingTableUrl, PartitionedFile, + ListingTableUrl, PartitionedFile, TableSchema, compute_all_files_statistics, }; +use datafusion_execution::cache::TableScopedPath; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_expr::dml::InsertOp; @@ -44,14 +43,24 @@ use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion_physical_expr::create_lex_ordering; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::ExecutionPlan; -use futures::{future, stream, Stream, StreamExt, TryStreamExt}; +use datafusion_physical_plan::empty::EmptyExec; +use futures::{Stream, StreamExt, TryStreamExt, future, stream}; use object_store::ObjectStore; -use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +/// Result of a file listing 
operation from [`ListingTable::list_files_for_scan`]. +#[derive(Debug)] +pub struct ListFilesResult { + /// File groups organized by the partitioning strategy. + pub file_groups: Vec, + /// Aggregated statistics for all files. + pub statistics: Statistics, + /// Whether files are grouped by partition values (enables Hash partitioning). + pub grouped_by_partition: bool, +} + /// Built in [`TableProvider`] that reads data from one or more files as a single table. /// /// The files are read using an [`ObjectStore`] instance, for example from @@ -178,13 +187,11 @@ pub struct ListingTable { /// The SQL definition for this table, if any definition: Option, /// Cache for collected file statistics - collected_statistics: FileStatisticsCache, + collected_statistics: Arc, /// Constraints applied to this table constraints: Constraints, /// Column default expressions for columns that are not physically present in the data files column_defaults: HashMap, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - schema_adapter_factory: Option>, /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters expr_adapter_factory: Option>, } @@ -227,7 +234,6 @@ impl ListingTable { collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), constraints: Constraints::default(), column_defaults: HashMap::new(), - schema_adapter_factory: config.schema_adapter_factory, expr_adapter_factory: config.expr_adapter_factory, }; @@ -255,7 +261,7 @@ impl ListingTable { /// multiple times in the same session. /// /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. - pub fn with_cache(mut self, cache: Option) -> Self { + pub fn with_cache(mut self, cache: Option>) -> Self { self.collected_statistics = cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); self @@ -282,83 +288,151 @@ impl ListingTable { self.schema_source } - /// Set the [`SchemaAdapterFactory`] for this [`ListingTable`] + /// Deprecated: Set the [`SchemaAdapterFactory`] for this [`ListingTable`] /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. + /// `SchemaAdapterFactory` has been removed. Use [`ListingTableConfig::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. 
/// - /// # Example: Adding Schema Evolution Support - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion_catalog_listing::{ListingTable, ListingTableConfig, ListingOptions}; - /// # use datafusion_datasource::ListingTableUrl; - /// # use datafusion_datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion_datasource_parquet::file_format::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # let table_path = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// # let config = ListingTableConfig::new(table_path).with_listing_options(options).with_schema(schema); - /// # let table = ListingTable::try_new(config).unwrap(); - /// let table_with_evolution = table - /// .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)); - /// ``` - /// See [`ListingTableConfig::with_schema_adapter_factory`] for an example of custom SchemaAdapterFactory. + /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use ListingTableConfig::with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] pub fn with_schema_adapter_factory( self, - schema_adapter_factory: Arc, + _schema_adapter_factory: Arc, ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this table - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() + // No-op - just return self unchanged + self } - /// Creates a schema adapter for mapping between file and table schemas + /// Deprecated: Returns the [`SchemaAdapterFactory`] used by this [`ListingTable`]. /// - /// Uses the configured schema adapter factory if available, otherwise falls back - /// to the default implementation. - fn create_schema_adapter(&self) -> Box { - let table_schema = self.schema(); - match &self.schema_adapter_factory { - Some(factory) => { - factory.create_with_projected_schema(Arc::clone(&table_schema)) - } - None => DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema)), - } + /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. + /// See `upgrading.md` for more details. + /// + /// Always returns `None`. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn schema_adapter_factory(&self) -> Option> { + None } - /// Creates a file source and applies schema adapter factory if available - fn create_file_source_with_schema_adapter( - &self, - ) -> datafusion_common::Result> { - let mut source = self.options.format.file_source(); - // Apply schema adapter to source if available - // - // The source will use this SchemaAdapter to adapt data batches as they flow up the plan. - // Note: ListingTable also creates a SchemaAdapter in `scan()` but that is only used to adapt collected statistics. 
- if let Some(factory) = &self.schema_adapter_factory { - source = source.with_schema_adapter_factory(Arc::clone(factory))?; - } - Ok(source) + /// Creates a file source for this table + fn create_file_source(&self) -> Arc { + let table_schema = TableSchema::new( + Arc::clone(&self.file_schema), + self.options + .table_partition_cols + .iter() + .map(|(col, field)| Arc::new(Field::new(col, field.clone(), false))) + .collect(), + ); + + self.options.format.file_source(table_schema) } - /// If file_sort_order is specified, creates the appropriate physical expressions + /// Creates output ordering from user-specified file_sort_order or derives + /// from file orderings when user doesn't specify. + /// + /// If user specified `file_sort_order`, that takes precedence. + /// Otherwise, attempts to derive common ordering from file orderings in + /// the provided file groups. pub fn try_create_output_ordering( &self, execution_props: &ExecutionProps, + file_groups: &[FileGroup], ) -> datafusion_common::Result> { - create_lex_ordering( - &self.table_schema, - &self.options.file_sort_order, - execution_props, - ) + // If user specified sort order, use that + if !self.options.file_sort_order.is_empty() { + return create_lex_ordering( + &self.table_schema, + &self.options.file_sort_order, + execution_props, + ); + } + if let Some(ordering) = derive_common_ordering_from_files(file_groups) { + return Ok(vec![ordering]); + } + Ok(vec![]) + } +} + +/// Derives a common ordering from file orderings across all file groups. +/// +/// Returns the common ordering if all files have compatible orderings, +/// otherwise returns None. +/// +/// The function finds the longest common prefix among all file orderings. +/// For example, if files have orderings `[a, b, c]` and `[a, b]`, the common +/// ordering is `[a, b]`. 
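Before the implementation below, a self-contained sketch of the prefix folding it performs, with `&str` column names standing in for `PhysicalSortExpr`s (a simplification; the real comparison also covers sort options such as direction and null ordering):

    /// Fold per-file orderings into their longest common prefix.
    /// None if any file lacks an ordering or the prefix is empty.
    fn common_ordering<'a>(
        file_orderings: &[Option<Vec<&'a str>>],
    ) -> Option<Vec<&'a str>> {
        let mut common: Option<Vec<&str>> = None;
        for ordering in file_orderings {
            let ordering = ordering.as_ref()?; // any unordered file => None
            common = Some(match common {
                None => ordering.clone(),
                Some(current) => {
                    let prefix: Vec<&str> = current
                        .iter()
                        .zip(ordering.iter())
                        .take_while(|(a, b)| a == b)
                        .map(|(a, _)| *a)
                        .collect();
                    if prefix.is_empty() {
                        return None; // no common prefix at all
                    }
                    prefix
                }
            });
        }
        common
    }

    fn main() {
        // [a, b, c] and [a, b] share the prefix [a, b]
        let files = [Some(vec!["a", "b", "c"]), Some(vec!["a", "b"])];
        assert_eq!(common_ordering(&files), Some(vec!["a", "b"]));
        // One unordered file poisons the result
        assert_eq!(common_ordering(&[Some(vec!["a"]), None]), None);
    }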
+fn derive_common_ordering_from_files(file_groups: &[FileGroup]) -> Option<LexOrdering> {
+    enum CurrentOrderingState {
+        /// Initial state before processing any files
+        FirstFile,
+        /// Some common ordering found so far
+        SomeOrdering(LexOrdering),
+        /// No files have ordering
+        NoOrdering,
+    }
+    let mut state = CurrentOrderingState::FirstFile;
+
+    // Fold each file's ordering into the running common-prefix state
+    for group in file_groups {
+        for file in group.iter() {
+            state = match (&state, &file.ordering) {
+                // If this is the first file with ordering, set it as current
+                (CurrentOrderingState::FirstFile, Some(ordering)) => {
+                    CurrentOrderingState::SomeOrdering(ordering.clone())
+                }
+                (CurrentOrderingState::FirstFile, None) => {
+                    CurrentOrderingState::NoOrdering
+                }
+                // If we have an existing ordering, find common prefix with new ordering
+                (CurrentOrderingState::SomeOrdering(current), Some(ordering)) => {
+                    // Find common prefix between current and new ordering
+                    let prefix_len = current
+                        .as_ref()
+                        .iter()
+                        .zip(ordering.as_ref().iter())
+                        .take_while(|(a, b)| a == b)
+                        .count();
+                    if prefix_len == 0 {
+                        log::trace!(
+                            "Cannot derive common ordering: no common prefix between orderings {current:?} and {ordering:?}"
+                        );
+                        return None;
+                    } else {
+                        let ordering =
+                            LexOrdering::new(current.as_ref()[..prefix_len].to_vec())
+                                .expect("prefix_len > 0, so ordering must be valid");
+                        CurrentOrderingState::SomeOrdering(ordering)
+                    }
+                }
+                // If one file has ordering and another doesn't, there is no common
+                // ordering. Return None and log a trace message explaining why.
+                (CurrentOrderingState::SomeOrdering(ordering), None)
+                | (CurrentOrderingState::NoOrdering, Some(ordering)) => {
+                    log::trace!(
+                        "Cannot derive common ordering: some files have ordering {ordering:?}, others don't"
+                    );
+                    return None;
+                }
+                // Both have no ordering, remain in NoOrdering state
+                (CurrentOrderingState::NoOrdering, None) => {
+                    CurrentOrderingState::NoOrdering
+                }
+            };
+        }
+    }
+
+    match state {
+        CurrentOrderingState::SomeOrdering(ordering) => Some(ordering),
+        _ => None,
+    }
+}

@@ -374,10 +448,6 @@ fn can_be_evaluated_for_partition_pruning(
 #[async_trait]
 impl TableProvider for ListingTable {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.table_schema)
     }
@@ -418,7 +488,7 @@ impl TableProvider for ListingTable {
             .options
             .table_partition_cols
             .iter()
-            .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone()))
+            .map(|col| Ok(Arc::new(self.table_schema.field_with_name(&col.0)?.clone())))
             .collect::<Result<Vec<_>>>()?;

         let table_partition_col_names = table_partition_cols
@@ -437,7 +507,11 @@ impl TableProvider for ListingTable {
         // at the same time. This is because the limit should be applied after the filters are applied.
let statistic_file_limit = if filters.is_empty() { limit } else { None }; - let (mut partitioned_file_lists, statistics) = self + let ListFilesResult { + file_groups: mut partitioned_file_lists, + statistics, + grouped_by_partition: partitioned_by_file_group, + } = self .list_files_for_scan(state, &partition_filters, statistic_file_limit) .await?; @@ -447,7 +521,10 @@ impl TableProvider for ListingTable { return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); } - let output_ordering = self.try_create_output_ordering(state.execution_props())?; + let output_ordering = self.try_create_output_ordering( + state.execution_props(), + &partitioned_file_lists, + )?; match state .config_options() .execution @@ -469,7 +546,9 @@ impl TableProvider for ListingTable { if new_groups.len() <= self.options.target_partitions { partitioned_file_lists = new_groups; } else { - log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") + log::debug!( + "attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered" + ) } } None => {} // no ordering required @@ -483,7 +562,7 @@ impl TableProvider for ListingTable { ))))); }; - let file_source = self.create_file_source_with_schema_adapter()?; + let file_source = self.create_file_source(); // create the execution plan let plan = self @@ -491,20 +570,16 @@ impl TableProvider for ListingTable { .format .create_physical_plan( state, - FileScanConfigBuilder::new( - object_store_url, - Arc::clone(&self.file_schema), - file_source, - ) - .with_file_groups(partitioned_file_lists) - .with_constraints(self.constraints.clone()) - .with_statistics(statistics) - .with_projection_indices(projection) - .with_limit(limit) - .with_output_ordering(output_ordering) - .with_table_partition_cols(table_partition_cols) - .with_expr_adapter(self.expr_adapter_factory.clone()) - .build(), + FileScanConfigBuilder::new(object_store_url, file_source) + .with_file_groups(partitioned_file_lists) + .with_constraints(self.constraints.clone()) + .with_statistics(statistics) + .with_projection_indices(projection)? 
+ .with_limit(limit) + .with_output_ordering(output_ordering) + .with_expr_adapter(self.expr_adapter_factory.clone()) + .with_partitioned_by_file_group(partitioned_by_file_group) + .build(), ) .await?; @@ -574,6 +649,15 @@ impl TableProvider for ListingTable { let keep_partition_by_columns = state.config_options().execution.keep_partition_by_columns; + // Invalidate cache entries for this table if they exist + if let Some(lfc) = state.runtime_env().cache_manager.get_list_files_cache() { + let key = TableScopedPath { + table: table_path.get_table_ref().clone(), + path: table_path.prefix().clone(), + }; + let _ = lfc.remove(&key); + } + // Sink related option, apart from format let config = FileSinkConfig { original_url: String::default(), @@ -585,9 +669,11 @@ impl TableProvider for ListingTable { insert_op, keep_partition_by_columns, file_extension: self.options().format.get_ext(), + file_output_mode: FileOutputMode::Automatic, }; - let orderings = self.try_create_output_ordering(state.execution_props())?; + // For writes, we only use user-specified ordering (no file groups to derive from) + let orderings = self.try_create_output_ordering(state.execution_props(), &[])?; // It is sufficient to pass only one of the equivalent orderings: let order_requirements = orderings.into_iter().next().map(Into::into); @@ -611,11 +697,15 @@ impl ListingTable { ctx: &'a dyn Session, filters: &'a [Expr], limit: Option, - ) -> datafusion_common::Result<(Vec, Statistics)> { + ) -> datafusion_common::Result { let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? } else { - return Ok((vec![], Statistics::new_unknown(&self.file_schema))); + return Ok(ListFilesResult { + file_groups: vec![], + statistics: Statistics::new_unknown(&self.file_schema), + grouped_by_partition: false, + }); }; // list files (with partitions) let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { @@ -632,16 +722,19 @@ impl ListingTable { let meta_fetch_concurrency = ctx.config_options().execution.meta_fetch_concurrency; let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); - // collect the statistics if required by the config + // collect the statistics and ordering if required by the config let files = file_list .map(|part_file| async { let part_file = part_file?; - let statistics = if self.options.collect_stat { - self.do_collect_statistics(ctx, &store, &part_file).await? + let (statistics, ordering) = if self.options.collect_stat { + self.do_collect_statistics_and_ordering(ctx, &store, &part_file) + .await? } else { - Arc::new(Statistics::new_unknown(&self.file_schema)) + (Arc::new(Statistics::new_unknown(&self.file_schema)), None) }; - Ok(part_file.with_statistics(statistics)) + Ok(part_file + .with_statistics(statistics) + .with_ordering(ordering)) }) .boxed() .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); @@ -649,65 +742,97 @@ impl ListingTable { let (file_group, inexact_stats) = get_files_with_limit(files, limit, self.options.collect_stat).await?; - let file_groups = file_group.split_files(self.options.target_partitions); - let (mut file_groups, mut stats) = compute_all_files_statistics( + // Threshold: 0 = disabled, N > 0 = enabled when distinct_keys >= N + // + // When enabled, files are grouped by their Hive partition column values, allowing + // FileScanConfig to declare Hash partitioning. This enables the optimizer to skip + // hash repartitioning for aggregates and joins on partition columns. 
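The comment above compresses the decision rule for `optimizer.preserve_file_partitions`; a condensed restatement of just that branch condition, not the actual control flow (the grouping call and `target_partitions` handling are elided):

    /// Keep partition-value grouping only when the feature is enabled
    /// (threshold > 0), the table has partition columns, and enough
    /// distinct partition groups exist to make Hash partitioning useful.
    fn keep_partition_grouping(
        threshold: usize,
        distinct_groups: usize,
        has_partition_cols: bool,
    ) -> bool {
        threshold > 0 && has_partition_cols && distinct_groups >= threshold
    }

    fn main() {
        assert!(!keep_partition_grouping(0, 16, true)); // feature disabled
        assert!(!keep_partition_grouping(8, 3, true)); // too few groups: fall back to split_files
        assert!(keep_partition_grouping(8, 16, true)); // Hash partitioning preserved
        assert!(!keep_partition_grouping(8, 16, false)); // no partition columns
    }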
+        let threshold = ctx.config_options().optimizer.preserve_file_partitions;
+
+        let (file_groups, grouped_by_partition) = if threshold > 0
+            && !self.options.table_partition_cols.is_empty()
+        {
+            let grouped =
+                file_group.group_by_partition_values(self.options.target_partitions);
+            if grouped.len() >= threshold {
+                (grouped, true)
+            } else {
+                let all_files: Vec<_> =
+                    grouped.into_iter().flat_map(|g| g.into_inner()).collect();
+                (
+                    FileGroup::new(all_files).split_files(self.options.target_partitions),
+                    false,
+                )
+            }
+        } else {
+            (
+                file_group.split_files(self.options.target_partitions),
+                false,
+            )
+        };
+
+        let (file_groups, stats) = compute_all_files_statistics(
             file_groups,
             self.schema(),
             self.options.collect_stat,
             inexact_stats,
         )?;
-        let schema_adapter = self.create_schema_adapter();
-        let (schema_mapper, _) = schema_adapter.map_schema(self.file_schema.as_ref())?;
-
-        stats.column_statistics =
-            schema_mapper.map_column_statistics(&stats.column_statistics)?;
-        file_groups.iter_mut().try_for_each(|file_group| {
-            if let Some(stat) = file_group.statistics_mut() {
-                stat.column_statistics =
-                    schema_mapper.map_column_statistics(&stat.column_statistics)?;
-            }
-            Ok::<_, DataFusionError>(())
-        })?;
-        Ok((file_groups, stats))
+        // Note: Statistics already include both file columns and partition columns.
+        // PartitionedFile::with_statistics automatically appends exact partition column
+        // statistics (min=max=partition_value, null_count=0, distinct_count=1) computed
+        // from partition_values.
+        Ok(ListFilesResult {
+            file_groups,
+            statistics: stats,
+            grouped_by_partition,
+        })
     }

-    /// Collects statistics for a given partitioned file.
+    /// Collects statistics and ordering for a given partitioned file.
     ///
-    /// This method first checks if the statistics for the given file are already cached.
-    /// If they are, it returns the cached statistics.
-    /// If they are not, it infers the statistics from the file and stores them in the cache.
-    async fn do_collect_statistics(
+    /// This method first checks the cache. On a hit with still-valid metadata it
+    /// returns the cached statistics and ordering; on a miss it infers both
+    /// statistics and ordering in a single metadata read for efficiency.
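The cache-hit path in the function below hinges on `is_valid_for`: an entry is only reused while the object's metadata still matches what was recorded at insert time. A sketch of that staleness check, assuming identity is (size, last_modified, e_tag); the concrete fields `CachedFileMetadata` compares are not shown in this diff, and `FileMeta`/`CachedEntry` are simplified stand-ins:

    use std::time::SystemTime;

    /// Simplified stand-in for object_store::ObjectMeta.
    #[derive(Clone, PartialEq)]
    struct FileMeta {
        size: u64,
        last_modified: SystemTime,
        e_tag: Option<String>,
    }

    struct CachedEntry<T> {
        seen_meta: FileMeta,
        payload: T, // statistics + ordering in the real cache
    }

    impl<T> CachedEntry<T> {
        /// Reuse the entry only if the file looks unchanged since caching.
        fn is_valid_for(&self, current: &FileMeta) -> bool {
            self.seen_meta == *current
        }
    }

    fn main() {
        let meta = FileMeta {
            size: 100,
            last_modified: SystemTime::now(),
            e_tag: None,
        };
        let entry = CachedEntry { seen_meta: meta.clone(), payload: "stats" };
        assert!(entry.is_valid_for(&meta));
        println!("reusing cached payload: {}", entry.payload);
        // A rewritten file (new size) invalidates the entry, forcing re-inference
        let rewritten = FileMeta { size: 200, ..meta };
        assert!(!entry.is_valid_for(&rewritten));
    }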
+ async fn do_collect_statistics_and_ordering( &self, ctx: &dyn Session, store: &Arc, part_file: &PartitionedFile, - ) -> datafusion_common::Result> { - match self - .collected_statistics - .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) + ) -> datafusion_common::Result<(Arc, Option)> { + use datafusion_execution::cache::cache_manager::CachedFileMetadata; + + let path = &part_file.object_meta.location; + let meta = &part_file.object_meta; + + // Check cache first - if we have valid cached statistics and ordering + if let Some(cached) = self.collected_statistics.get(path) + && cached.is_valid_for(meta) { - Some(statistics) => Ok(statistics), - None => { - let statistics = self - .options - .format - .infer_stats( - ctx, - store, - Arc::clone(&self.file_schema), - &part_file.object_meta, - ) - .await?; - let statistics = Arc::new(statistics); - self.collected_statistics.put_with_extra( - &part_file.object_meta.location, - Arc::clone(&statistics), - &part_file.object_meta, - ); - Ok(statistics) - } + // Return cached statistics and ordering + return Ok((Arc::clone(&cached.statistics), cached.ordering.clone())); } + + // Cache miss or invalid: fetch both statistics and ordering in a single metadata read + let file_meta = self + .options + .format + .infer_stats_and_ordering(ctx, store, Arc::clone(&self.file_schema), meta) + .await?; + + let statistics = Arc::new(file_meta.statistics); + + // Store in cache + self.collected_statistics.put( + path, + CachedFileMetadata::new( + meta.clone(), + Arc::clone(&statistics), + file_meta.ordering.clone(), + ), + ); + + Ok((statistics, file_meta.ordering)) } } @@ -756,28 +881,25 @@ async fn get_files_with_limit( let file = file_result?; // Update file statistics regardless of state - if collect_stats { - if let Some(file_stats) = &file.statistics { - num_rows = if file_group.is_empty() { - // For the first file, just take its row count - file_stats.num_rows - } else { - // For subsequent files, accumulate the counts - num_rows.add(&file_stats.num_rows) - }; - } + if collect_stats && let Some(file_stats) = &file.statistics { + num_rows = if file_group.is_empty() { + // For the first file, just take its row count + file_stats.num_rows + } else { + // For subsequent files, accumulate the counts + num_rows.add(&file_stats.num_rows) + }; } // Always add the file to our group file_group.push(file); // Check if we've hit the limit (if one was specified) - if let Some(limit) = limit { - if let Precision::Exact(row_count) = num_rows { - if row_count > limit { - state = ProcessingState::ReachedLimit; - } - } + if let Some(limit) = limit + && let Precision::Exact(row_count) = num_rows + && row_count > limit + { + state = ProcessingState::ReachedLimit; } } // If we still have files in the stream, it means that the limit kicked @@ -786,3 +908,145 @@ async fn get_files_with_limit( let inexact_stats = all_files.next().await.is_some(); Ok((file_group, inexact_stats)) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::compute::SortOptions; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + /// Helper to create a PhysicalSortExpr + fn sort_expr( + name: &str, + idx: usize, + descending: bool, + nulls_first: bool, + ) -> PhysicalSortExpr { + PhysicalSortExpr::new( + Arc::new(Column::new(name, idx)), + SortOptions { + descending, + nulls_first, + }, + ) + } + + /// Helper to create a LexOrdering (unwraps the Option) + fn lex_ordering(exprs: Vec) -> LexOrdering { + 
LexOrdering::new(exprs).expect("expected non-empty ordering") + } + + /// Helper to create a PartitionedFile with optional ordering + fn create_file(name: &str, ordering: Option) -> PartitionedFile { + PartitionedFile::new(name.to_string(), 1024).with_ordering(ordering) + } + + #[test] + fn test_derive_common_ordering_all_files_same_ordering() { + // All files have the same ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![ + FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering.clone())), + create_file("f2.parquet", Some(ordering.clone())), + ]), + FileGroup::new(vec![create_file("f3.parquet", Some(ordering.clone()))]), + ]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } + + #[test] + fn test_derive_common_ordering_common_prefix() { + // Files have different orderings but share a common prefix + let ordering_abc = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + sort_expr("c", 2, false, true), + ]); + let ordering_ab = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + ]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_abc)), + create_file("f2.parquet", Some(ordering_ab.clone())), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering_ab)); + } + + #[test] + fn test_derive_common_ordering_no_common_prefix() { + // Files have completely different orderings -> returns None + let ordering_a = lex_ordering(vec![sort_expr("a", 0, false, true)]); + let ordering_b = lex_ordering(vec![sort_expr("b", 1, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_a)), + create_file("f2.parquet", Some(ordering_b)), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_mixed_with_none() { + // Some files have ordering, some don't -> returns None + let ordering = lex_ordering(vec![sort_expr("a", 0, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering)), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_all_none() { + // No files have ordering -> returns None + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", None), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_empty_groups() { + // Empty file groups -> returns None + let file_groups: Vec = vec![]; + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_single_file() { + // Single file with ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![FileGroup::new(vec![create_file( + "f1.parquet", + Some(ordering.clone()), + )])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } +} diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml index 
a1db45654be01..1009e9aee477b 100644 --- a/datafusion/catalog/Cargo.toml +++ b/datafusion/catalog/Cargo.toml @@ -49,5 +49,8 @@ object_store = { workspace = true } parking_lot = { workspace = true } tokio = { workspace = true } +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true diff --git a/datafusion/catalog/src/async.rs b/datafusion/catalog/src/async.rs index 1c830c976d8b8..87b7b7c3431a1 100644 --- a/datafusion/catalog/src/async.rs +++ b/datafusion/catalog/src/async.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::{error::Result, not_impl_err, HashMap, TableReference}; +use datafusion_common::{HashMap, TableReference, error::Result, not_impl_err}; use datafusion_execution::config::SessionConfig; use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; @@ -37,10 +37,6 @@ impl SchemaProvider for ResolvedSchemaProvider { self.owner_name.as_deref() } - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn table_names(&self) -> Vec { self.cached_tables.keys().cloned().collect() } @@ -60,7 +56,9 @@ impl SchemaProvider for ResolvedSchemaProvider { } fn deregister_table(&self, name: &str) -> Result>> { - not_impl_err!("Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported") + not_impl_err!( + "Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported" + ) } fn table_exist(&self, name: &str) -> bool { @@ -113,10 +111,6 @@ struct ResolvedCatalogProvider { cached_schemas: HashMap>, } impl CatalogProvider for ResolvedCatalogProvider { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn schema_names(&self) -> Vec { self.cached_schemas.keys().cloned().collect() } @@ -158,10 +152,6 @@ struct ResolvedCatalogProviderList { cached_catalogs: HashMap>, } impl CatalogProviderList for ResolvedCatalogProviderList { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn register_catalog( &self, _name: String, @@ -193,7 +183,7 @@ impl CatalogProviderList for ResolvedCatalogProviderList { /// /// See the [remote_catalog.rs] for an end to end example /// -/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/remote_catalog.rs #[async_trait] pub trait AsyncSchemaProvider: Send + Sync { /// Lookup a table in the schema provider @@ -422,17 +412,14 @@ pub trait AsyncCatalogProviderList: Send + Sync { #[cfg(test)] mod tests { - use std::{ - any::Any, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, + use std::sync::{ + Arc, + atomic::{AtomicU32, Ordering}, }; use arrow::datatypes::SchemaRef; use async_trait::async_trait; - use datafusion_common::{error::Result, Statistics, TableReference}; + use datafusion_common::{Statistics, TableReference, error::Result}; use datafusion_execution::config::SessionConfig; use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::ExecutionPlan; @@ -445,10 +432,6 @@ mod tests { struct MockTableProvider {} #[async_trait] impl TableProvider for MockTableProvider { - fn as_any(&self) -> &dyn Any { - self - } - /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef { unimplemented!() diff --git a/datafusion/catalog/src/catalog.rs b/datafusion/catalog/src/catalog.rs index 
71b9eccf9d657..34cdf74440cb3 100644 --- a/datafusion/catalog/src/catalog.rs +++ b/datafusion/catalog/src/catalog.rs @@ -20,8 +20,8 @@ use std::fmt::Debug; use std::sync::Arc; pub use crate::schema::SchemaProvider; -use datafusion_common::not_impl_err; use datafusion_common::Result; +use datafusion_common::not_impl_err; /// Represents a catalog, comprising a number of named schemas. /// @@ -61,7 +61,7 @@ use datafusion_common::Result; /// schemas and tables exist. /// /// [Delta Lake]: https://delta.io/ -/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/remote_catalog.rs /// /// The [`CatalogProvider`] can support this use case, but it takes some care. /// The planning APIs in DataFusion are not `async` and thus network IO can not @@ -100,16 +100,12 @@ use datafusion_common::Result; /// /// [`datafusion-cli`]: https://datafusion.apache.org/user-guide/cli/index.html /// [`DynamicFileCatalogProvider`]: https://github.com/apache/datafusion/blob/31b9b48b08592b7d293f46e75707aad7dadd7cbc/datafusion-cli/src/catalog.rs#L75 -/// [`catalog.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/catalog.rs +/// [`catalog.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/catalog.rs /// [delta-rs]: https://github.com/delta-io/delta-rs /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123 /// /// [`TableProvider`]: crate::TableProvider -pub trait CatalogProvider: Debug + Sync + Send { - /// Returns the catalog provider as [`Any`] - /// so that it can be downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - +pub trait CatalogProvider: Any + Debug + Sync + Send { /// Retrieves the list of available schema names in this catalog. fn schema_names(&self) -> Vec; @@ -152,15 +148,31 @@ pub trait CatalogProvider: Debug + Sync + Send { } } +impl dyn CatalogProvider { + /// Returns `true` if the catalog provider is of type `T`. + /// + /// Prefer this over `downcast_ref::().is_some()`. Works correctly when + /// called on `Arc` via auto-deref. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Attempts to downcast this catalog provider to a concrete type `T`, + /// returning `None` if the provider is not of that type. + /// + /// Works correctly when called on `Arc` via auto-deref, + /// unlike `(&arc as &dyn Any).downcast_ref::()` which would attempt to + /// downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} + /// Represent a list of named [`CatalogProvider`]s. /// /// Please see the documentation on [`CatalogProvider`] for details of /// implementing a custom catalog. -pub trait CatalogProviderList: Debug + Sync + Send { - /// Returns the catalog list as [`Any`] - /// so that it can be downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - +pub trait CatalogProviderList: Any + Debug + Sync + Send { /// Adds a new catalog to this catalog list /// If a catalog of the same name existed before, it is replaced in the list and returned. fn register_catalog( @@ -175,3 +187,23 @@ pub trait CatalogProviderList: Debug + Sync + Send { /// Retrieves a specific catalog by name, provided it exists. 
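With `CatalogProvider: Any`, the per-implementation `as_any` boilerplate disappears and callers downcast through the new inherent helpers on the trait object. A usage sketch (the `MemoryCatalogProvider` target and the trait methods are real; the surrounding function is illustrative):

    use std::sync::Arc;
    use datafusion_catalog::{CatalogProvider, MemoryCatalogProvider};

    fn describe(catalog: &Arc<dyn CatalogProvider>) {
        // Auto-deref means this inspects the provider itself, not the Arc,
        // which is exactly the pitfall the doc comments above warn about.
        if let Some(mem) = catalog.downcast_ref::<MemoryCatalogProvider>() {
            println!("in-memory catalog with schemas: {:?}", mem.schema_names());
        } else {
            println!("some other CatalogProvider implementation");
        }
    }

    fn main() {
        let catalog: Arc<dyn CatalogProvider> = Arc::new(MemoryCatalogProvider::new());
        describe(&catalog);
    }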
fn catalog(&self, name: &str) -> Option>; } + +impl dyn CatalogProviderList { + /// Returns `true` if the catalog provider list is of type `T`. + /// + /// Prefer this over `downcast_ref::().is_some()`. Works correctly when + /// called on `Arc` via auto-deref. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Attempts to downcast this catalog provider list to a concrete type `T`, + /// returning `None` if the provider list is not of that type. + /// + /// Works correctly when called on `Arc` via + /// auto-deref, unlike `(&arc as &dyn Any).downcast_ref::()` which would + /// attempt to downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} diff --git a/datafusion/catalog/src/cte_worktable.rs b/datafusion/catalog/src/cte_worktable.rs index d6b2a453118c9..dd313ebb4cbff 100644 --- a/datafusion/catalog/src/cte_worktable.rs +++ b/datafusion/catalog/src/cte_worktable.rs @@ -17,20 +17,17 @@ //! CteWorkTable implementation used for recursive queries +use std::borrow::Cow; use std::sync::Arc; -use std::{any::Any, borrow::Cow}; -use crate::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_physical_plan::work_table::WorkTableExec; - -use datafusion_physical_plan::ExecutionPlan; - use datafusion_common::error::Result; use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableType}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::work_table::WorkTableExec; -use crate::TableProvider; +use crate::{ScanArgs, ScanResult, Session, TableProvider}; /// The temporary working table where the previous iteration of a recursive query is stored /// Naming is based on PostgreSQL's implementation. @@ -67,10 +64,6 @@ impl CteWorkTable { #[async_trait] impl TableProvider for CteWorkTable { - fn as_any(&self) -> &dyn Any { - self - } - fn get_logical_plan(&'_ self) -> Option> { None } @@ -85,16 +78,28 @@ impl TableProvider for CteWorkTable { async fn scan( &self, - _state: &dyn Session, - _projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, ) -> Result> { - // TODO: pushdown filters and limits - Ok(Arc::new(WorkTableExec::new( + let options = ScanArgs::default() + .with_projection(projection.map(|p| p.as_slice())) + .with_filters(Some(filters)) + .with_limit(limit); + Ok(self.scan_with_args(state, options).await?.into_inner()) + } + + async fn scan_with_args<'a>( + &self, + _state: &dyn Session, + args: ScanArgs<'a>, + ) -> Result { + Ok(ScanResult::new(Arc::new(WorkTableExec::new( self.name.clone(), Arc::clone(&self.table_schema), - ))) + args.projection().map(|p| p.to_vec()), + )?))) } fn supports_filters_pushdown( diff --git a/datafusion/catalog/src/default_table_source.rs b/datafusion/catalog/src/default_table_source.rs index 11963c06c88f5..60f85891d66e6 100644 --- a/datafusion/catalog/src/default_table_source.rs +++ b/datafusion/catalog/src/default_table_source.rs @@ -17,13 +17,13 @@ //! 
Default TableSource implementation used in DataFusion physical plans +use std::borrow::Cow; use std::sync::Arc; -use std::{any::Any, borrow::Cow}; use crate::TableProvider; use arrow::datatypes::SchemaRef; -use datafusion_common::{internal_err, Constraints}; +use datafusion_common::{Constraints, internal_err}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType}; /// Implements [`TableSource`] for a [`TableProvider`] @@ -46,12 +46,6 @@ impl DefaultTableSource { } impl TableSource for DefaultTableSource { - /// Returns the table source as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any { - self - } - /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef { self.table_provider.schema() @@ -97,11 +91,7 @@ pub fn provider_as_source( pub fn source_as_provider( source: &Arc, ) -> datafusion_common::Result> { - match source - .as_ref() - .as_any() - .downcast_ref::() - { + match source.as_ref().downcast_ref::() { Some(source) => Ok(Arc::clone(&source.table_provider)), _ => internal_err!("TableSource was not DefaultTableSource"), } @@ -117,10 +107,6 @@ fn preserves_table_type() { #[async_trait] impl TableProvider for TestTempTable { - fn as_any(&self) -> &dyn Any { - self - } - fn table_type(&self) -> TableType { TableType::Temporary } diff --git a/datafusion/catalog/src/dynamic_file/catalog.rs b/datafusion/catalog/src/dynamic_file/catalog.rs index ccccb9762eb4c..f93bd35cd7f0a 100644 --- a/datafusion/catalog/src/dynamic_file/catalog.rs +++ b/datafusion/catalog/src/dynamic_file/catalog.rs @@ -19,7 +19,6 @@ use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; use async_trait::async_trait; -use std::any::Any; use std::fmt::Debug; use std::sync::Arc; @@ -42,10 +41,6 @@ impl DynamicFileCatalog { } impl CatalogProviderList for DynamicFileCatalog { - fn as_any(&self) -> &dyn Any { - self - } - fn register_catalog( &self, name: String, @@ -87,10 +82,6 @@ impl DynamicFileCatalogProvider { } impl CatalogProvider for DynamicFileCatalogProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn schema_names(&self) -> Vec { self.inner.schema_names() } @@ -137,10 +128,6 @@ impl DynamicFileSchemaProvider { #[async_trait] impl SchemaProvider for DynamicFileSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn table_names(&self) -> Vec { self.inner.table_names() } diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/catalog/src/empty.rs similarity index 89% rename from datafusion/core/src/datasource/empty.rs rename to datafusion/catalog/src/empty.rs index 77686c5eb7c27..1ff36ecf360a2 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/catalog/src/empty.rs @@ -17,19 +17,17 @@ //! [`EmptyTable`] useful for testing. -use std::any::Any; use std::sync::Arc; use arrow::datatypes::*; use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_common::project_schema; - -use crate::datasource::{TableProvider, TableType}; -use crate::error::Result; -use crate::logical_expr::Expr; -use datafusion_physical_plan::empty::EmptyExec; +use datafusion_common::{Result, project_schema}; +use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::empty::EmptyExec; + +use crate::Session; +use crate::TableProvider; /// An empty plan that is useful for testing and generating plans /// without mapping them to actual data. 
@@ -57,10 +55,6 @@ impl EmptyTable { #[async_trait] impl TableProvider for EmptyTable { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index d733551f44051..34c677c3dd43e 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -24,23 +24,27 @@ use crate::{CatalogProviderList, SchemaProvider, TableProvider}; use arrow::array::builder::{BooleanBuilder, UInt8Builder}; use arrow::{ array::{StringBuilder, UInt64Builder}, - datatypes::{DataType, Field, Schema, SchemaRef}, + datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}, record_batch::RecordBatch, }; use async_trait::async_trait; +use datafusion_common::DataFusionError; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::error::Result; use datafusion_common::types::NativeType; -use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; -use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_expr::function::WindowUDFFieldArgs; +use datafusion_expr::{ + AggregateUDF, ReturnFieldArgs, ScalarUDF, Signature, TypeSignature, WindowUDF, +}; use datafusion_expr::{TableType, Volatility}; +use datafusion_physical_plan::SendableRecordBatchStream; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::SendableRecordBatchStream; use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Debug; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; pub const INFORMATION_SCHEMA: &str = "information_schema"; pub(crate) const TABLES: &str = "tables"; @@ -137,11 +141,11 @@ impl InformationSchemaConfig { let catalog = self.catalog_list.catalog(&catalog_name).unwrap(); for schema_name in catalog.schema_names() { - if schema_name != INFORMATION_SCHEMA { - if let Some(schema) = catalog.schema(&schema_name) { - let schema_owner = schema.owner_name(); - builder.add_schemata(&catalog_name, &schema_name, schema_owner); - } + if schema_name != INFORMATION_SCHEMA + && let Some(schema) = catalog.schema(&schema_name) + { + let schema_owner = schema.owner_name(); + builder.add_schemata(&catalog_name, &schema_name, schema_owner); } } } @@ -215,11 +219,16 @@ impl InformationSchemaConfig { fn make_df_settings( &self, config_options: &ConfigOptions, + runtime_env: &Arc, builder: &mut InformationSchemaDfSettingsBuilder, ) { for entry in config_options.entries() { builder.add_setting(entry); } + // Add runtime configuration entries + for entry in runtime_env.config_entries() { + builder.add_setting(entry); + } } fn make_routines( @@ -245,7 +254,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udf.signature()), - return_type, + return_type.as_ref(), "SCALAR", udf.documentation().map(|d| d.description.to_string()), udf.documentation().map(|d| d.syntax_example.to_string()), @@ -265,7 +274,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udaf.signature()), - return_type, + return_type.as_ref(), "AGGREGATE", udaf.documentation().map(|d| d.description.to_string()), udaf.documentation().map(|d| d.syntax_example.to_string()), @@ -285,7 +294,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udwf.signature()), - return_type, + 
return_type.as_ref(), "WINDOW", udwf.documentation().map(|d| d.description.to_string()), udwf.documentation().map(|d| d.syntax_example.to_string()), @@ -415,14 +424,28 @@ fn get_udf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let scalar_arguments = vec![None; arg_fields.len()]; let return_type = udf - .return_type(&arg_types) - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, return_type) }) @@ -441,14 +464,24 @@ fn get_udaf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); let return_type = udaf - .return_type(&arg_types) - .ok() - .map(|t| remove_native_type_prefix(NativeType::from(t))); + .return_field(&arg_fields) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, return_type) }) @@ -467,28 +500,38 @@ fn get_udwf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let return_type = udwf + .field(WindowUDFFieldArgs::new(&arg_fields, udwf.name())) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); - (arg_types, None) + (arg_types, return_type) }) .collect::>()) } } #[inline] -fn remove_native_type_prefix(native_type: NativeType) -> String { +fn remove_native_type_prefix(native_type: &NativeType) -> String { format!("{native_type}") } #[async_trait] impl SchemaProvider for InformationSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn table_names(&self) -> Vec { INFORMATION_SCHEMA_TABLES .iter() @@ -679,7 +722,7 @@ impl InformationSchemaViewBuilder { catalog_name: impl AsRef, schema_name: impl AsRef, table_name: impl AsRef, - definition: Option>, + definition: Option<&(impl AsRef + ?Sized)>, ) { // Note: append_value is actually infallible. 
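The same placeholder-argument pattern used in the routines builder above can be exercised directly against any `ScalarUDF`. A sketch mirroring that change (the `arg_{i}` placeholder naming and all-`None` scalar arguments are taken from the diff itself; the wrapping helper is hypothetical):

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr::{ReturnFieldArgs, ScalarUDF};

    /// Resolve a UDF's return type for a given argument signature the way
    /// information_schema.routines now does: wrap the types in placeholder
    /// fields and ask for the full return field.
    fn return_type_for(udf: &ScalarUDF, arg_types: &[DataType]) -> Result<DataType> {
        let arg_fields: Vec<FieldRef> = arg_types
            .iter()
            .enumerate()
            .map(|(i, t)| Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)))
            .collect();
        // No literal argument values are known at catalog time
        let scalar_arguments = vec![None; arg_fields.len()];
        let field = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &scalar_arguments,
        })?;
        Ok(field.data_type().clone())
    }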
self.catalog_names.append_value(catalog_name.as_ref()); @@ -1060,7 +1103,12 @@ impl PartitionStream for InformationSchemaDfSettings { // TODO: Stream this futures::stream::once(async move { // create a mem table with the names of tables - config.make_df_settings(ctx.session_config().options(), &mut builder); + let runtime_env = ctx.runtime_env(); + config.make_df_settings( + ctx.session_config().options(), + &runtime_env, + &mut builder, + ); Ok(builder.finish()) }), )) @@ -1156,7 +1204,7 @@ struct InformationSchemaRoutinesBuilder { } impl InformationSchemaRoutinesBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn add_routine( &mut self, catalog_name: impl AsRef, @@ -1164,7 +1212,7 @@ impl InformationSchemaRoutinesBuilder { routine_name: impl AsRef, routine_type: impl AsRef, is_deterministic: bool, - data_type: Option>, + data_type: Option<&impl AsRef>, function_type: impl AsRef, description: Option>, syntax_example: Option>, @@ -1290,7 +1338,7 @@ struct InformationSchemaParametersBuilder { } impl InformationSchemaParametersBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn add_parameter( &mut self, specific_catalog: impl AsRef, @@ -1298,7 +1346,7 @@ impl InformationSchemaParametersBuilder { specific_name: impl AsRef, ordinal_position: u64, parameter_mode: impl AsRef, - parameter_name: Option>, + parameter_name: Option<&(impl AsRef + ?Sized)>, data_type: impl AsRef, parameter_default: Option>, is_variadic: bool, @@ -1397,11 +1445,9 @@ mod tests { // InformationSchemaConfig::make_tables used this before `table_type` // existed but should not, as it may be expensive. async fn table(&self, _: &str) -> Result>> { - panic!("InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type") - } - - fn as_any(&self) -> &dyn Any { - unimplemented!("not required for these tests") + panic!( + "InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type" + ) } fn table_names(&self) -> Vec { @@ -1414,10 +1460,6 @@ mod tests { } impl CatalogProviderList for Fixture { - fn as_any(&self) -> &dyn Any { - unimplemented!("not required for these tests") - } - fn register_catalog( &self, _: String, @@ -1436,10 +1478,6 @@ mod tests { } impl CatalogProvider for Fixture { - fn as_any(&self) -> &dyn Any { - unimplemented!("not required for these tests") - } - fn schema_names(&self) -> Vec { vec!["aschema".to_string()] } diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 1c5e38438724e..33d54b7cb89d5 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -23,6 +23,7 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Interfaces and default implementations of catalogs and schemas. //! 
@@ -33,6 +34,7 @@ pub mod cte_worktable; pub mod default_table_source; +pub mod empty; pub mod information_schema; pub mod listing_schema; pub mod memory; @@ -46,13 +48,13 @@ mod dynamic_file; mod schema; mod table; +pub use r#async::*; pub use catalog::*; pub use datafusion_session::Session; pub use dynamic_file::catalog::*; pub use memory::{ MemTable, MemoryCatalogProvider, MemoryCatalogProviderList, MemorySchemaProvider, }; -pub use r#async::*; pub use schema::*; pub use table::*; diff --git a/datafusion/catalog/src/listing_schema.rs b/datafusion/catalog/src/listing_schema.rs index af96cfc15fc82..d38fe659aaa97 100644 --- a/datafusion/catalog/src/listing_schema.rs +++ b/datafusion/catalog/src/listing_schema.rs @@ -17,7 +17,6 @@ //! [`ListingSchemaProvider`]: [`SchemaProvider`] that scans ObjectStores for tables automatically -use std::any::Any; use std::collections::HashSet; use std::path::Path; use std::sync::{Arc, Mutex}; @@ -26,7 +25,7 @@ use crate::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::Session; use datafusion_common::{ - internal_datafusion_err, DFSchema, DataFusionError, HashMap, TableReference, + DFSchema, DataFusionError, HashMap, TableReference, internal_datafusion_err, }; use datafusion_expr::CreateExternalTable; @@ -127,22 +126,13 @@ impl ListingSchemaProvider { .factory .create( state, - &CreateExternalTable { - schema: Arc::new(DFSchema::empty()), + &CreateExternalTable::builder( name, - location: table_url, - file_type: self.format.clone(), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: Default::default(), - constraints: Default::default(), - column_defaults: Default::default(), - }, + table_url, + self.format.clone(), + Arc::new(DFSchema::empty()), + ) + .build(), ) .await?; let _ = @@ -155,10 +145,6 @@ impl ListingSchemaProvider { #[async_trait] impl SchemaProvider for ListingSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn table_names(&self) -> Vec { self.tables .lock() diff --git a/datafusion/catalog/src/memory/catalog.rs b/datafusion/catalog/src/memory/catalog.rs index b71888c54e9d6..ebe6b9dfa0ebc 100644 --- a/datafusion/catalog/src/memory/catalog.rs +++ b/datafusion/catalog/src/memory/catalog.rs @@ -21,7 +21,6 @@ use crate::{CatalogProvider, CatalogProviderList, SchemaProvider}; use dashmap::DashMap; use datafusion_common::exec_err; -use std::any::Any; use std::sync::Arc; /// Simple in-memory list of catalogs @@ -47,10 +46,6 @@ impl Default for MemoryCatalogProviderList { } impl CatalogProviderList for MemoryCatalogProviderList { - fn as_any(&self) -> &dyn Any { - self - } - fn register_catalog( &self, name: String, @@ -90,10 +85,6 @@ impl Default for MemoryCatalogProvider { } impl CatalogProvider for MemoryCatalogProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn schema_names(&self) -> Vec { self.schemas.iter().map(|s| s.key().clone()).collect() } diff --git a/datafusion/catalog/src/memory/schema.rs b/datafusion/catalog/src/memory/schema.rs index f1b3628f7affc..46b0beb440613 100644 --- a/datafusion/catalog/src/memory/schema.rs +++ b/datafusion/catalog/src/memory/schema.rs @@ -20,8 +20,7 @@ use crate::{SchemaProvider, TableProvider}; use async_trait::async_trait; use dashmap::DashMap; -use datafusion_common::{exec_err, DataFusionError}; -use std::any::Any; +use datafusion_common::{DataFusionError, exec_err}; use std::sync::Arc; /// Simple in-memory implementation of a schema. 
@@ -47,10 +46,6 @@ impl Default for MemorySchemaProvider { #[async_trait] impl SchemaProvider for MemorySchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn table_names(&self) -> Vec { self.tables .iter() diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs index 90224f6a37bc3..8102c15079658 100644 --- a/datafusion/catalog/src/memory/table.rs +++ b/datafusion/catalog/src/memory/table.rs @@ -17,27 +17,36 @@ //! [`MemTable`] for querying `Vec` by DataFusion. -use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; use crate::TableProvider; -use arrow::datatypes::SchemaRef; +use arrow::array::{ + Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array, +}; +use arrow::compute::kernels::zip::zip; +use arrow::compute::{and, filter_record_batch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::error::Result; -use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err}; use datafusion_common_runtime::JoinSet; use datafusion_datasource::memory::{MemSink, MemorySourceConfig}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::DataSourceExec; use datafusion_expr::dml::InsertOp; use datafusion_expr::{Expr, SortExpr, TableType}; -use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; +use datafusion_physical_expr::{ + LexOrdering, create_physical_expr, create_physical_sort_exprs, +}; use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ - common, ExecutionPlan, ExecutionPlanProperties, Partitioning, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PhysicalExpr, PlanProperties, common, }; use datafusion_session::Session; @@ -204,10 +213,6 @@ impl MemTable { #[async_trait] impl TableProvider for MemTable { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -295,4 +300,338 @@ impl TableProvider for MemTable { fn get_column_default(&self, column: &str) -> Option<&Expr> { self.column_defaults.get(column) } + + async fn delete_from( + &self, + state: &dyn Session, + filters: Vec, + ) -> Result> { + // Early exit if table has no partitions + if self.batches.is_empty() { + return Ok(Arc::new(DmlResultExec::new(0))); + } + + *self.sort_order.lock() = vec![]; + + let mut total_deleted: u64 = 0; + let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; + + for partition_data in &self.batches { + let mut partition = partition_data.write().await; + let mut new_batches = Vec::with_capacity(partition.len()); + + for batch in partition.iter() { + if batch.num_rows() == 0 { + continue; + } + + // Evaluate filters - None means "match all rows" + let filter_mask = evaluate_filters_to_mask( + &filters, + batch, + &df_schema, + state.execution_props(), + )?; + + let (delete_count, keep_mask) = match filter_mask { + Some(mask) => { + // Count rows where mask is true (will be deleted) + let count = mask.iter().filter(|v| v == &Some(true)).count(); + // Keep rows where predicate is false or NULL (SQL three-valued logic) + let keep: BooleanArray = + mask.iter().map(|v| Some(v != Some(true))).collect(); + (count, keep) + } + None => { + // 
+                        // No filters = delete all rows
+                        (
+                            batch.num_rows(),
+                            BooleanArray::from(vec![false; batch.num_rows()]),
+                        )
+                    }
+                };
+
+                total_deleted += delete_count as u64;
+
+                let filtered_batch = filter_record_batch(batch, &keep_mask)?;
+                if filtered_batch.num_rows() > 0 {
+                    new_batches.push(filtered_batch);
+                }
+            }
+
+            *partition = new_batches;
+        }
+
+        Ok(Arc::new(DmlResultExec::new(total_deleted)))
+    }
+
+    async fn update(
+        &self,
+        state: &dyn Session,
+        assignments: Vec<(String, Expr)>,
+        filters: Vec<Expr>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Early exit if table has no partitions
+        if self.batches.is_empty() {
+            return Ok(Arc::new(DmlResultExec::new(0)));
+        }
+
+        // Validate column names upfront with clear error messages
+        let available_columns: Vec<&str> = self
+            .schema
+            .fields()
+            .iter()
+            .map(|f| f.name().as_str())
+            .collect();
+        for (column_name, _) in &assignments {
+            if self.schema.field_with_name(column_name).is_err() {
+                return plan_err!(
+                    "UPDATE failed: column '{}' does not exist. Available columns: {}",
+                    column_name,
+                    available_columns.join(", ")
+                );
+            }
+        }
+
+        let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;
+
+        // Create physical expressions for assignments upfront (outside batch loop)
+        let physical_assignments: HashMap<String, Arc<dyn PhysicalExpr>> = assignments
+            .iter()
+            .map(|(name, expr)| {
+                let physical_expr =
+                    create_physical_expr(expr, &df_schema, state.execution_props())?;
+                Ok((name.clone(), physical_expr))
+            })
+            .collect::<Result<_>>()?;
+
+        *self.sort_order.lock() = vec![];
+
+        let mut total_updated: u64 = 0;
+
+        for partition_data in &self.batches {
+            let mut partition = partition_data.write().await;
+            let mut new_batches = Vec::with_capacity(partition.len());
+
+            for batch in partition.iter() {
+                if batch.num_rows() == 0 {
+                    continue;
+                }
+
+                // Evaluate filters - None means "match all rows"
+                let filter_mask = evaluate_filters_to_mask(
+                    &filters,
+                    batch,
+                    &df_schema,
+                    state.execution_props(),
+                )?;
+
+                let (update_count, update_mask) = match filter_mask {
+                    Some(mask) => {
+                        // Count rows where mask is true (will be updated)
+                        let count = mask.iter().filter(|v| v == &Some(true)).count();
+                        // Normalize mask: only true (not NULL) triggers update
+                        let normalized: BooleanArray =
+                            mask.iter().map(|v| Some(v == Some(true))).collect();
+                        (count, normalized)
+                    }
+                    None => {
+                        // No filters = update all rows
+                        (
+                            batch.num_rows(),
+                            BooleanArray::from(vec![true; batch.num_rows()]),
+                        )
+                    }
+                };
+
+                total_updated += update_count as u64;
+
+                if update_count == 0 {
+                    new_batches.push(batch.clone());
+                    continue;
+                }
+
+                let mut new_columns: Vec<ArrayRef> =
+                    Vec::with_capacity(batch.num_columns());
+
+                for field in self.schema.fields() {
+                    let column_name = field.name();
+                    let original_column =
+                        batch.column_by_name(column_name).ok_or_else(|| {
+                            datafusion_common::DataFusionError::Internal(format!(
+                                "Column '{column_name}' not found in batch"
+                            ))
+                        })?;
+
+                    let new_column = if let Some(physical_expr) =
+                        physical_assignments.get(column_name.as_str())
+                    {
+                        // Use evaluate_selection to only evaluate on matching rows.
+                        // This avoids errors (e.g., divide-by-zero) on rows that won't
+                        // be updated. The result is scattered back with nulls for
+                        // non-matching rows, which zip() will replace with originals.
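+                        // Illustrative walk-through (hypothetical values): with
+                        // update_mask = [true, false, true] and assignment
+                        // `b = 10 / a` over a = [5, 0, 2], evaluate_selection
+                        // yields [2, NULL, 5] and never divides by the
+                        // unselected a = 0; zip(update_mask, new, original)
+                        // then emits [2, b_orig, 5].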
+                        let new_values =
+                            physical_expr.evaluate_selection(batch, &update_mask)?;
+                        let new_array = new_values.into_array(batch.num_rows())?;
+
+                        // Convert to &dyn Array which implements Datum
+                        let new_arr: &dyn Array = new_array.as_ref();
+                        let orig_arr: &dyn Array = original_column.as_ref();
+                        zip(&update_mask, &new_arr, &orig_arr)?
+                    } else {
+                        Arc::clone(original_column)
+                    };
+
+                    new_columns.push(new_column);
+                }
+
+                let updated_batch =
+                    ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?;
+                new_batches.push(updated_batch);
+            }
+
+            *partition = new_batches;
+        }
+
+        Ok(Arc::new(DmlResultExec::new(total_updated)))
+    }
+}
+
+/// Evaluate filter expressions against a batch and return a combined boolean mask.
+/// Returns None if filters is empty (meaning "match all rows").
+/// The returned mask has true for rows that match the filter predicates.
+fn evaluate_filters_to_mask(
+    filters: &[Expr],
+    batch: &RecordBatch,
+    df_schema: &DFSchema,
+    execution_props: &datafusion_expr::execution_props::ExecutionProps,
+) -> Result<Option<BooleanArray>> {
+    if filters.is_empty() {
+        return Ok(None);
+    }
+
+    let mut combined_mask: Option<BooleanArray> = None;
+
+    for filter_expr in filters {
+        let physical_expr =
+            create_physical_expr(filter_expr, df_schema, execution_props)?;
+
+        let result = physical_expr.evaluate(batch)?;
+        let array = result.into_array(batch.num_rows())?;
+        let bool_array = array
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .ok_or_else(|| {
+                datafusion_common::DataFusionError::Internal(
+                    "Filter did not evaluate to boolean".to_string(),
+                )
+            })?
+            .clone();
+
+        combined_mask = Some(match combined_mask {
+            Some(existing) => and(&existing, &bool_array)?,
+            None => bool_array,
+        });
+    }
+
+    Ok(combined_mask)
+}
+
+/// Returns a single row with the count of affected rows.
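+// Consumer-side sketch (not shown in this module): the count can be read back
+// by executing the returned plan and downcasting the single `count` column.
+// `ctx` (a `TaskContext`) is assumed to be available:
+//
+//     let plan = table.delete_from(state, filters).await?;
+//     let batches = common::collect(plan.execute(0, ctx)?).await?;
+//     let n = batches[0]
+//         .column(0)
+//         .as_any()
+//         .downcast_ref::<UInt64Array>()
+//         .expect("count column is UInt64")
+//         .value(0);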
+#[derive(Debug)]
+struct DmlResultExec {
+    rows_affected: u64,
+    schema: SchemaRef,
+    properties: Arc<PlanProperties>,
+}
+
+impl DmlResultExec {
+    fn new(rows_affected: u64) -> Self {
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "count",
+            DataType::UInt64,
+            false,
+        )]));
+
+        let properties = PlanProperties::new(
+            datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)),
+            Partitioning::UnknownPartitioning(1),
+            datafusion_physical_plan::execution_plan::EmissionType::Final,
+            datafusion_physical_plan::execution_plan::Boundedness::Bounded,
+        );
+
+        Self {
+            rows_affected,
+            schema,
+            properties: Arc::new(properties),
+        }
+    }
+}
+
+impl DisplayAs for DmlResultExec {
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default
+            | DisplayFormatType::Verbose
+            | DisplayFormatType::TreeRender => {
+                write!(f, "DmlResultExec: rows_affected={}", self.rows_affected)
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for DmlResultExec {
+    fn name(&self) -> &str {
+        "DmlResultExec"
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        // Create a single batch with the count
+        let count_array = UInt64Array::from(vec![self.rows_affected]);
+        let batch = ArrowRecordBatch::try_new(
+            Arc::clone(&self.schema),
+            vec![Arc::new(count_array) as ArrayRef],
+        )?;
+
+        // Create a stream that yields just this one batch
+        let stream = futures::stream::iter(vec![Ok(batch)]);
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            Arc::clone(&self.schema),
+            stream,
+        )))
+    }
+
+    fn apply_expressions(
+        &self,
+        _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs
index 9ba55256f1824..d99027593ccce 100644
--- a/datafusion/catalog/src/schema.rs
+++ b/datafusion/catalog/src/schema.rs
@@ -19,7 +19,7 @@
 //! representing collections of named tables.
 
 use async_trait::async_trait;
-use datafusion_common::{exec_err, DataFusionError};
+use datafusion_common::{DataFusionError, exec_err};
 use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
@@ -34,17 +34,13 @@ use datafusion_expr::TableType;
 ///
 /// [`CatalogProvider`]: super::CatalogProvider
 #[async_trait]
-pub trait SchemaProvider: Debug + Sync + Send {
+pub trait SchemaProvider: Any + Debug + Sync + Send {
     /// Returns the owner of the Schema, default is None. This value is reported
     /// as part of `information_tables.schemata
     fn owner_name(&self) -> Option<&str> {
         None
     }
 
-    /// Returns this `SchemaProvider` as [`Any`] so that it can be downcast to a
-    /// specific implementation.
-    fn as_any(&self) -> &dyn Any;
-
     /// Retrieves the list of available table names in this schema.
     fn table_names(&self) -> Vec<String>;
 
@@ -68,7 +64,7 @@ pub trait SchemaProvider: Debug + Sync + Send {
     ///
     /// If a table of the same name was already registered, returns "Table
     /// already exists" error.
-    #[allow(unused_variables)]
+    #[expect(unused_variables)]
     fn register_table(
         &self,
         name: String,
@@ -81,7 +77,7 @@
     /// schema and returns the previously registered [`TableProvider`], if any.
     ///
     /// If no `name` table exists, returns Ok(None).
-    #[allow(unused_variables)]
+    #[expect(unused_variables)]
     fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
         exec_err!("schema provider does not support deregistering tables")
     }
@@ -89,3 +85,23 @@
     /// Returns true if table exist in the schema provider, false otherwise.
     fn table_exist(&self, name: &str) -> bool;
 }
+
+impl dyn SchemaProvider {
+    /// Returns `true` if the schema provider is of type `T`.
+    ///
+    /// Prefer this over `downcast_ref::<T>().is_some()`. Works correctly when
+    /// called on `Arc<dyn SchemaProvider>` via auto-deref.
+    pub fn is<T: SchemaProvider>(&self) -> bool {
+        (self as &dyn Any).is::<T>()
+    }
+
+    /// Attempts to downcast this schema provider to a concrete type `T`,
+    /// returning `None` if the provider is not of that type.
+    ///
+    /// Works correctly when called on `Arc<dyn SchemaProvider>` via auto-deref,
+    /// unlike `(&arc as &dyn Any).downcast_ref::<T>()` which would attempt to
+    /// downcast the `Arc` itself.
+    pub fn downcast_ref<T: SchemaProvider>(&self) -> Option<&T> {
+        (self as &dyn Any).downcast_ref()
+    }
+}
diff --git a/datafusion/catalog/src/stream.rs b/datafusion/catalog/src/stream.rs
index f4a2338b8eecb..8501ea65902e2 100644
--- a/datafusion/catalog/src/stream.rs
+++ b/datafusion/catalog/src/stream.rs
@@ -17,7 +17,6 @@
 
 //! TableProvider for stream sources, such as FIFO files
 
-use std::any::Any;
 use std::fmt::Formatter;
 use std::fs::{File, OpenOptions};
 use std::io::BufReader;
@@ -28,7 +27,7 @@ use std::sync::Arc;
 use crate::{Session, TableProvider, TableProviderFactory};
 use arrow::array::{RecordBatch, RecordBatchReader, RecordBatchWriter};
 use arrow::datatypes::SchemaRef;
-use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result};
+use datafusion_common::{Constraints, DataFusionError, Result, config_err, plan_err};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_datasource::sink::{DataSink, DataSinkExec};
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
@@ -303,10 +302,6 @@ impl StreamTable
 
 #[async_trait]
 impl TableProvider for StreamTable {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         Arc::clone(self.0.source.schema())
     }
@@ -405,10 +400,6 @@ impl DisplayAs for StreamWrite
 
 #[async_trait]
 impl DataSink for StreamWrite {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> &SchemaRef {
         self.0.source.schema()
     }
diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs
index 082e74dab9a15..e609877c2b778 100644
--- a/datafusion/catalog/src/streaming.rs
+++ b/datafusion/catalog/src/streaming.rs
@@ -17,22 +17,20 @@ //!
A simplified [`TableProvider`] for streaming partitioned datasets -use std::any::Any; use std::sync::Arc; -use crate::Session; -use crate::TableProvider; - use arrow::datatypes::SchemaRef; -use datafusion_common::{plan_err, DFSchema, Result}; +use async_trait::async_trait; +use datafusion_common::{DFSchema, Result, plan_err}; use datafusion_expr::{Expr, SortExpr, TableType}; -use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_expr::equivalence::project_ordering; +use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs}; use datafusion_physical_plan::ExecutionPlan; - -use async_trait::async_trait; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use log::debug; +use crate::{Session, TableProvider}; + /// A [`TableProvider`] that streams a set of [`PartitionStream`] #[derive(Debug)] pub struct StreamingTable { @@ -82,10 +80,6 @@ impl StreamingTable { #[async_trait] impl TableProvider for StreamingTable { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -105,7 +99,22 @@ impl TableProvider for StreamingTable { let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; let eqp = state.execution_props(); - create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)? + let original_sort_exprs = + create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?; + + if let Some(p) = projection { + // When performing a projection, the output columns will not match + // the original physical sort expression indices. Also the sort columns + // may not be in the output projection. To correct for these issues + // we need to project the ordering based on the output schema. + let schema = Arc::new(self.schema.project(p)?); + LexOrdering::new(original_sort_exprs) + .and_then(|lex_ordering| project_ordering(&lex_ordering, &schema)) + .map(|lex_ordering| lex_ordering.to_vec()) + .unwrap_or_default() + } else { + original_sort_exprs + } } else { vec![] }; diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 11c9af01a7a54..5d1391bed1172 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use crate::session::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::Result; -use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_common::{Constraints, Statistics, not_impl_err}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::Expr; use datafusion_expr::dml::InsertOp; @@ -48,11 +48,7 @@ use datafusion_physical_plan::ExecutionPlan; /// [`RecordBatch`]: https://docs.rs/arrow/latest/arrow/record_batch/struct.RecordBatch.html /// [`CatalogProvider`]: super::CatalogProvider #[async_trait] -pub trait TableProvider: Debug + Sync + Send { - /// Returns the table provider as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - +pub trait TableProvider: Any + Debug + Sync + Send { /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; @@ -84,10 +80,10 @@ pub trait TableProvider: Debug + Sync + Send { None } - /// Create an [`ExecutionPlan`] for scanning the table with optionally - /// specified `projection`, `filter` and `limit`, described below. 
+ /// Create an [`ExecutionPlan`] for scanning the table with optional + /// `projection`, `filter`, and `limit`, described below. /// - /// The `ExecutionPlan` is responsible scanning the datasource's + /// The returned `ExecutionPlan` is responsible for scanning the datasource's /// partitions in a streaming, parallelized fashion. /// /// # Projection @@ -96,33 +92,30 @@ pub trait TableProvider: Debug + Sync + Send { /// specified. The projection is a set of indexes of the fields in /// [`Self::schema`]. /// - /// DataFusion provides the projection to scan only the columns actually - /// used in the query to improve performance, an optimization called - /// "Projection Pushdown". Some datasources, such as Parquet, can use this - /// information to go significantly faster when only a subset of columns is - /// required. + /// DataFusion provides the projection so the scan reads only the columns + /// actually used in the query, an optimization called "Projection + /// Pushdown". Some datasources, such as Parquet, can use this information + /// to go significantly faster when only a subset of columns is required. /// /// # Filters /// /// A list of boolean filter [`Expr`]s to evaluate *during* the scan, in the /// manner specified by [`Self::supports_filters_pushdown`]. Only rows for - /// which *all* of the `Expr`s evaluate to `true` must be returned (aka the - /// expressions are `AND`ed together). + /// which *all* of the `Expr`s evaluate to `true` must be returned (that is, + /// the expressions are `AND`ed together). /// - /// To enable filter pushdown you must override - /// [`Self::supports_filters_pushdown`] as the default implementation does - /// not and `filters` will be empty. + /// To enable filter pushdown, override + /// [`Self::supports_filters_pushdown`]. The default implementation does not + /// push down filters, and `filters` will be empty. /// - /// DataFusion pushes filtering into the scans whenever possible - /// ("Filter Pushdown"), and depending on the format and the - /// implementation of the format, evaluating the predicate during the scan - /// can increase performance significantly. + /// DataFusion pushes filters into scans whenever possible ("Filter + /// Pushdown"). Depending on the data format and implementation, evaluating + /// predicates during the scan can significantly improve performance. /// /// ## Note: Some columns may appear *only* in Filters /// - /// In certain cases, a query may only use a certain column in a Filter that - /// has been completely pushed down to the scan. In this case, the - /// projection will not contain all the columns found in the filter + /// In some cases, a query may use a column only in a filter and the + /// projection will not contain all columns referenced by the filter /// expressions. /// /// For example, given the query `SELECT t.a FROM t WHERE t.b > 5`, @@ -154,15 +147,40 @@ pub trait TableProvider: Debug + Sync + Send { /// /// # Limit /// - /// If `limit` is specified, must only produce *at least* this many rows, - /// (though it may return more). Like Projection Pushdown and Filter - /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as - /// possible, called "Limit Pushdown" as some sources can use this - /// information to improve their performance. Note that if there are any - /// Inexact filters pushed down, the LIMIT cannot be pushed down. 
This is
-    /// because inexact filters do not guarantee that every filtered row is
-    /// removed, so applying the limit could lead to too few rows being available
-    /// to return as a final result.
+    /// If `limit` is specified, the scan must produce *at least* this many
+    /// rows, though it may return more. Like Projection Pushdown and Filter
+    /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as
+    /// possible. This is called "Limit Pushdown", and some sources can use the
+    /// information to improve performance.
+    ///
+    /// Note: If any pushed-down filters are `Inexact`, the `LIMIT` cannot be
+    /// pushed down. Inexact filters do not guarantee that every filtered row is
+    /// removed, so applying the limit could leave too few rows to return in the
+    /// final result.
+    ///
+    /// # Evaluation Order
+    ///
+    /// The logical evaluation order is `filters`, then `limit`, then
+    /// `projection`.
+    ///
+    /// Note that `limit` applies to the filtered result, not to the unfiltered
+    /// input, and `projection` affects only which columns are returned, not
+    /// which rows qualify.
+    ///
+    /// For example, if a scan receives:
+    ///
+    /// - `projection = [a]`
+    /// - `filters = [b > 5]`
+    /// - `limit = Some(3)`
+    ///
+    /// It must logically produce results equivalent to:
+    ///
+    /// ```text
+    /// PROJECTION a (LIMIT 3 (SCAN WHERE b > 5))
+    /// ```
+    ///
+    /// As noted above, columns referenced only by pushed-down filters may be
+    /// absent from `projection`.
     async fn scan(
         &self,
         state: &dyn Session,
@@ -246,7 +264,6 @@ pub trait TableProvider: Debug + Sync + Send {
     ///
     /// #[async_trait]
     /// impl TableProvider for TestDataSource {
-    ///     # fn as_any(&self) -> &dyn Any { todo!() }
     ///     # fn schema(&self) -> SchemaRef { todo!() }
     ///     # fn table_type(&self) -> TableType { todo!() }
     ///     # async fn scan(&self, s: &dyn Session, p: Option<&Vec<usize>>, f: &[Expr], l: Option<usize>) -> Result<Arc<dyn ExecutionPlan>> {
@@ -328,6 +345,59 @@ pub trait TableProvider: Debug + Sync + Send {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         not_impl_err!("Insert into not implemented for this table")
     }
+
+    /// Delete rows matching the filter predicates.
+    ///
+    /// Returns an [`ExecutionPlan`] producing a single row with `count` (UInt64).
+    /// Empty `filters` deletes all rows.
+    async fn delete_from(
+        &self,
+        _state: &dyn Session,
+        _filters: Vec<Expr>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        not_impl_err!("DELETE not supported for {} table", self.table_type())
+    }
+
+    /// Update rows matching the filter predicates.
+    ///
+    /// Returns an [`ExecutionPlan`] producing a single row with `count` (UInt64).
+    /// Empty `filters` updates all rows.
+    async fn update(
+        &self,
+        _state: &dyn Session,
+        _assignments: Vec<(String, Expr)>,
+        _filters: Vec<Expr>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        not_impl_err!("UPDATE not supported for {} table", self.table_type())
+    }
+
+    /// Remove all rows from the table.
+    ///
+    /// Returns an [`ExecutionPlan`] producing a single row with `count` (UInt64),
+    /// representing the number of rows removed.
+    async fn truncate(&self, _state: &dyn Session) -> Result<Arc<dyn ExecutionPlan>> {
+        not_impl_err!("TRUNCATE not supported for {} table", self.table_type())
+    }
+}
+
+impl dyn TableProvider {
+    /// Returns `true` if the table provider is of type `T`.
+    ///
+    /// Prefer this over `downcast_ref::<T>().is_some()`. Works correctly when
+    /// called on `Arc<dyn TableProvider>` via auto-deref.
+    pub fn is<T: TableProvider>(&self) -> bool {
+        (self as &dyn Any).is::<T>()
+    }
+
+    /// Attempts to downcast this table provider to a concrete type `T`,
+    /// returning `None` if the provider is not of that type.
+ /// + /// Works correctly when called on `Arc` via auto-deref, + /// unlike `(&arc as &dyn Any).downcast_ref::()` which would attempt to + /// downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } } /// Arguments for scanning a table with [`TableProvider::scan_with_args`]. @@ -452,10 +522,49 @@ pub trait TableProviderFactory: Debug + Sync + Send { ) -> Result>; } +/// Describes arguments provided to the table function call. +pub struct TableFunctionArgs<'e, 's> { + /// Call arguments. + exprs: &'e [Expr], + /// Session within which the function is called. + session: &'s dyn Session, +} + +impl<'e, 's> TableFunctionArgs<'e, 's> { + /// Make a new [`TableFunctionArgs`]. + pub fn new(exprs: &'e [Expr], session: &'s dyn Session) -> Self { + Self { exprs, session } + } + + /// Get expressions passed as the called function arguments. + pub fn exprs(&self) -> &'e [Expr] { + self.exprs + } + + /// Get a session where the table function is called. + pub fn session(&self) -> &'s dyn Session { + self.session + } +} + /// A trait for table function implementations -pub trait TableFunctionImpl: Debug + Sync + Send { +pub trait TableFunctionImpl: Debug + Sync + Send + Any { + /// Create a table provider + #[deprecated( + since = "53.0.0", + note = "Implement `TableFunctionImpl::call_with_args` instead" + )] + fn call(&self, _exprs: &[Expr]) -> Result> { + internal_err!( + "TableFunctionImpl::call is not implemented. Implement TableFunctionImpl::call_with_args instead." + ) + } + /// Create a table provider - fn call(&self, args: &[Expr]) -> Result>; + fn call_with_args(&self, args: TableFunctionArgs) -> Result> { + #[expect(deprecated)] + self.call(args.exprs) + } } /// A table that uses a function to generate data @@ -484,7 +593,20 @@ impl TableFunction { } /// Get the function implementation and generate a table + #[deprecated( + since = "53.0.0", + note = "Use `TableFunction::create_table_provider_with_args` instead" + )] pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + #[expect(deprecated)] self.fun.call(args) } + + /// Get the function implementation and generate a table + pub fn create_table_provider_with_args( + &self, + args: TableFunctionArgs, + ) -> Result> { + self.fun.call_with_args(args) + } } diff --git a/datafusion/catalog/src/view.rs b/datafusion/catalog/src/view.rs index 89c6a4a224511..45084e65f23f2 100644 --- a/datafusion/catalog/src/view.rs +++ b/datafusion/catalog/src/view.rs @@ -17,15 +17,15 @@ //! View data source which uses a LogicalPlan as it's input. 
-use std::{any::Any, borrow::Cow, sync::Arc}; +use std::{borrow::Cow, sync::Arc}; use crate::Session; use crate::TableProvider; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::error::Result; use datafusion_common::Column; +use datafusion_common::error::Result; use datafusion_expr::TableType; use datafusion_expr::{Expr, LogicalPlan}; use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; @@ -83,10 +83,6 @@ impl ViewTable { #[async_trait] impl TableProvider for ViewTable { - fn as_any(&self) -> &dyn Any { - self - } - fn get_logical_plan(&'_ self) -> Option> { Some(Cow::Borrowed(&self.logical_plan)) } diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index e53d97b41360a..fd9a818bcb1d0 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -31,6 +31,9 @@ rust-version = { workspace = true } [package.metadata.docs.rs] all-features = true +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index cebd6e04cd1b1..ca618b19ed2f1 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -44,7 +44,7 @@ impl SpawnedTask { R: Send, { // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop - #[allow(clippy::disallowed_methods)] + #[expect(clippy::disallowed_methods)] let inner = tokio::task::spawn(trace_future(task)); Self { inner } } @@ -56,7 +56,7 @@ impl SpawnedTask { R: Send, { // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop - #[allow(clippy::disallowed_methods)] + #[expect(clippy::disallowed_methods)] let inner = tokio::task::spawn_blocking(trace_block(task)); Self { inner } } @@ -115,14 +115,14 @@ impl Drop for SpawnedTask { mod tests { use super::*; - use std::future::{pending, Pending}; + use std::future::{Pending, pending}; use tokio::{runtime::Runtime, sync::oneshot}; #[tokio::test] async fn runtime_shutdown() { let rt = Runtime::new().unwrap(); - #[allow(clippy::async_yields_async)] + #[expect(clippy::async_yields_async)] let task = rt .spawn(async { SpawnedTask::spawn(async { diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index 5d404d99e7760..cf45ccf3ef63a 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+#![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -31,5 +32,5 @@ mod trace_utils; pub use common::SpawnedTask; pub use join_set::JoinSet; pub use trace_utils::{ - set_join_set_tracer, trace_block, trace_future, JoinSetTracer, JoinSetTracerError, + JoinSetTracer, JoinSetTracerError, set_join_set_tracer, trace_block, trace_future, }; diff --git a/datafusion/common-runtime/src/trace_utils.rs b/datafusion/common-runtime/src/trace_utils.rs index c3a39c355fc88..f8adbe8825bc1 100644 --- a/datafusion/common-runtime/src/trace_utils.rs +++ b/datafusion/common-runtime/src/trace_utils.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use std::any::Any; use std::error::Error; use std::fmt::{Display, Formatter, Result as FmtResult}; diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index a9eb0f2220c69..740d4e45b8d05 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -31,6 +31,9 @@ rust-version = { workspace = true } [package.metadata.docs.rs] all-features = true +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -38,49 +41,55 @@ workspace = true name = "datafusion_common" [features] -avro = ["apache-avro"] backtrace = [] parquet_encryption = [ "parquet", "parquet/encryption", "dep:hex", ] -pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] force_hash_collisions = [] recursive_protection = ["dep:recursive"] parquet = ["dep:parquet"] sql = ["sqlparser"] +[[bench]] +harness = false +name = "with_hashes" + +[[bench]] +harness = false +name = "scalar_to_array" + +[[bench]] +harness = false +name = "stats_merge" + [dependencies] -ahash = { workspace = true } -apache-avro = { version = "0.20", default-features = false, features = [ - "bzip", - "snappy", - "xz", - "zstandard", -], optional = true } arrow = { workspace = true } arrow-ipc = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } chrono = { workspace = true } +foldhash = "0.2" half = { workspace = true } hashbrown = { workspace = true } hex = { workspace = true, optional = true } indexmap = { workspace = true } -libc = "0.2.177" +itertools = { workspace = true } +libc = "0.2.185" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } -paste = "1.0.15" -pyo3 = { version = "0.26", optional = true } recursive = { workspace = true, optional = true } sqlparser = { workspace = true, optional = true } tokio = { workspace = true } +uuid = { workspace = true, features = ["v4"] } [target.'cfg(target_family = "wasm")'.dependencies] web-time = "1.1.0" [dev-dependencies] chrono = { workspace = true } +criterion = { workspace = true } insta = { workspace = true } rand = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/common/benches/scalar_to_array.rs b/datafusion/common/benches/scalar_to_array.rs new file mode 100644 index 
0000000000000..90a152e515fe5 --- /dev/null +++ b/datafusion/common/benches/scalar_to_array.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `ScalarValue::to_array_of_size`, focusing on List +//! scalars. + +use arrow::array::{Array, ArrayRef, AsArray, StringViewBuilder}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::utils::SingleRowListArrayBuilder; +use std::sync::Arc; + +/// Build a `ScalarValue::List` of `num_elements` Utf8View strings whose +/// inner StringViewArray has `num_buffers` data buffers. +fn make_list_scalar(num_elements: usize, num_buffers: usize) -> ScalarValue { + let elements_per_buffer = num_elements.div_ceil(num_buffers); + + let mut small_arrays: Vec = Vec::new(); + let mut remaining = num_elements; + for buf_idx in 0..num_buffers { + let count = remaining.min(elements_per_buffer); + if count == 0 { + break; + } + let start = buf_idx * elements_per_buffer; + let mut builder = StringViewBuilder::with_capacity(count); + for i in start..start + count { + builder.append_value(format!("{i:024x}")); + } + small_arrays.push(Arc::new(builder.finish()) as ArrayRef); + remaining -= count; + } + + let refs: Vec<&dyn Array> = small_arrays.iter().map(|a| a.as_ref()).collect(); + let concated = arrow::compute::concat(&refs).unwrap(); + + let list_array = SingleRowListArrayBuilder::new(concated) + .with_field(&Field::new_list_field(DataType::Utf8View, true)) + .build_list_array(); + ScalarValue::List(Arc::new(list_array)) +} + +/// We want to measure the cost of doing the conversion and then also accessing +/// the results, to model what would happen during query evaluation. 
+fn consume_list_array(arr: &ArrayRef) { + let list_arr = arr.as_list::(); + let mut total_len: usize = 0; + for i in 0..list_arr.len() { + let inner = list_arr.value(i); + let sv = inner.as_string_view(); + for j in 0..sv.len() { + total_len += sv.value(j).len(); + } + } + std::hint::black_box(total_len); +} + +fn bench_list_to_array_of_size(c: &mut Criterion) { + let mut group = c.benchmark_group("list_to_array_of_size"); + + let num_elements = 1245; + let scalar_1buf = make_list_scalar(num_elements, 1); + let scalar_50buf = make_list_scalar(num_elements, 50); + + for batch_size in [256, 1024] { + group.bench_with_input( + BenchmarkId::new("1_buffer", batch_size), + &batch_size, + |b, &sz| { + b.iter(|| { + let arr = scalar_1buf.to_array_of_size(sz).unwrap(); + consume_list_array(&arr); + }); + }, + ); + group.bench_with_input( + BenchmarkId::new("50_buffers", batch_size), + &batch_size, + |b, &sz| { + b.iter(|| { + let arr = scalar_50buf.to_array_of_size(sz).unwrap(); + consume_list_array(&arr); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_list_to_array_of_size); +criterion_main!(benches); diff --git a/datafusion/common/benches/stats_merge.rs b/datafusion/common/benches/stats_merge.rs new file mode 100644 index 0000000000000..73229b6379360 --- /dev/null +++ b/datafusion/common/benches/stats_merge.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for `Statistics::try_merge_iter`. 
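+//!
+//! A minimal usage sketch of the API under benchmark (signature inferred from
+//! the benchmark body below, which passes a collection of `Statistics` and a
+//! `Schema` reference):
+//!
+//! ```ignore
+//! let merged = Statistics::try_merge_iter(&partition_stats, &schema)?;
+//! assert!(matches!(merged.num_rows, Precision::Exact(_)));
+//! ```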
+
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
+use datafusion_common::stats::Precision;
+use datafusion_common::{ColumnStatistics, ScalarValue, Statistics};
+
+/// Build a vector of `n` `Statistics` values, each with `num_cols` columns
+fn make_stats(n: usize, num_cols: usize) -> Vec<Statistics> {
+    (0..n)
+        .map(|i| {
+            let mut stats = Statistics::default()
+                .with_num_rows(Precision::Exact(100 + i))
+                .with_total_byte_size(Precision::Exact(8000 + i * 80));
+            for c in 0..num_cols {
+                let base = (i * num_cols + c) as i64;
+                stats = stats.add_column_statistics(
+                    ColumnStatistics::new_unknown()
+                        .with_null_count(Precision::Exact(i))
+                        .with_min_value(Precision::Exact(ScalarValue::Int64(Some(base))))
+                        .with_max_value(Precision::Exact(ScalarValue::Int64(Some(
+                            base + 1000,
+                        ))))
+                        .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(
+                            base * 100,
+                        )))),
+                );
+            }
+            stats
+        })
+        .collect()
+}
+
+fn bench_stats_merge(c: &mut Criterion) {
+    let mut group = c.benchmark_group("stats_merge");
+
+    for &num_partitions in &[10, 100, 500] {
+        for &num_cols in &[1, 5, 20] {
+            let items = make_stats(num_partitions, num_cols);
+            let schema = Arc::new(Schema::new(
+                (0..num_cols)
+                    .map(|i| Field::new(format!("col{i}"), DataType::Int64, true))
+                    .collect::<Vec<_>>(),
+            ));
+
+            let param = format!("{num_partitions}parts_{num_cols}cols");
+
+            group.bench_with_input(
+                BenchmarkId::new("try_merge_iter", &param),
+                &(&items, &schema),
+                |b, (items, schema)| {
+                    b.iter(|| {
+                        std::hint::black_box(
+                            Statistics::try_merge_iter(*items, schema).unwrap(),
+                        );
+                    });
+                },
+            );
+        }
+    }
+
+    group.finish();
+}
+
+criterion_group!(benches, bench_stats_merge);
+criterion_main!(benches);
diff --git a/datafusion/common/benches/with_hashes.rs b/datafusion/common/benches/with_hashes.rs
new file mode 100644
index 0000000000000..0e9c53c896a5e
--- /dev/null
+++ b/datafusion/common/benches/with_hashes.rs
@@ -0,0 +1,569 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Benchmarks for `with_hashes` function
+
+use arrow::array::{
+    Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, Int32Array,
+    Int64Array, ListArray, MapArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveArray,
+    RunArray, StringViewArray, StructArray, UnionArray, make_array,
+};
+use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
+use arrow::datatypes::{
+    ArrowDictionaryKeyType, DataType, Field, Fields, Int32Type, Int64Type, UnionFields,
+};
+use criterion::{Bencher, Criterion, criterion_group, criterion_main};
+use datafusion_common::hash_utils::RandomState;
+use datafusion_common::hash_utils::with_hashes;
+use rand::Rng;
+use rand::SeedableRng;
+use rand::distr::{Alphanumeric, Distribution, StandardUniform};
+use rand::prelude::StdRng;
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+struct BenchData {
+    name: &'static str,
+    array: ArrayRef,
+    /// Union arrays can't have null bitmasks added
+    supports_nulls: bool,
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let pool = StringPool::new(100, 64);
+    // pool with small strings for string view tests (<=12 bytes are inlined)
+    let small_pool = StringPool::new(100, 5);
+    let cases = [
+        BenchData {
+            name: "int64",
+            array: primitive_array::<Int64Type>(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "utf8",
+            array: pool.string_array::<i32>(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "large_utf8",
+            array: pool.string_array::<i64>(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "utf8_view",
+            array: pool.string_view_array(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "utf8_view (small)",
+            array: small_pool.string_view_array(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "dictionary_utf8_int32",
+            array: pool.dictionary_array::<Int32Type>(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "list_array",
+            array: list_array(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "map_array",
+            array: map_array(BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "sparse_union",
+            array: sparse_union_array(BATCH_SIZE),
+            supports_nulls: false,
+        },
+        BenchData {
+            name: "dense_union",
+            array: dense_union_array(BATCH_SIZE),
+            supports_nulls: false,
+        },
+        BenchData {
+            name: "struct_array",
+            array: create_struct_array(&pool, BATCH_SIZE),
+            supports_nulls: true,
+        },
+        BenchData {
+            name: "run_array_int32",
+            array: create_run_array::<Int32Type>(BATCH_SIZE),
+            supports_nulls: true,
+        },
+    ];
+
+    for BenchData {
+        name,
+        array,
+        supports_nulls,
+    } in cases
+    {
+        c.bench_function(&format!("{name}: single, no nulls"), |b| {
+            do_hash_test(b, std::slice::from_ref(&array));
+        });
+        c.bench_function(&format!("{name}: multiple, no nulls"), |b| {
+            let arrays = vec![array.clone(), array.clone(), array.clone()];
+            do_hash_test(b, &arrays);
+        });
+        // Union arrays can't have null bitmasks
+        if supports_nulls {
+            let nullable_array = add_nulls(&array);
+            c.bench_function(&format!("{name}: single, nulls"), |b| {
+                do_hash_test(b, std::slice::from_ref(&nullable_array));
+            });
+            c.bench_function(&format!("{name}: multiple, nulls"), |b| {
+                let arrays = vec![
+                    nullable_array.clone(),
+                    nullable_array.clone(),
+                    nullable_array.clone(),
+                ];
+                do_hash_test(b, &arrays);
+            });
+        }
+    }
+}
+
+fn do_hash_test(b: &mut Bencher, arrays: &[ArrayRef]) {
+    let state = RandomState::default();
+    b.iter(|| {
+        with_hashes(arrays, &state, |hashes| {
+            assert_eq!(hashes.len(), BATCH_SIZE); // make sure the result is used
+            Ok(())
+        })
+        .unwrap();
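+        // Callback-style API note (inferred from this usage): `with_hashes`
+        // computes one u64 hash per row across all input columns and lends the
+        // hash buffer to the closure, which lets the implementation reuse an
+        // internal buffer across calls instead of allocating a Vec per batch.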
+ }); +} + +fn create_null_mask(len: usize) -> NullBuffer +where + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + let null_density = 0.03; + let mut builder = NullBufferBuilder::new(len); + for _ in 0..len { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_non_null(); + } + } + builder.finish().expect("should be nulls in buffer") +} + +// Returns a new array that is the same as array, but with nulls +// Handles the special case of RunArray where nulls must be in the values array +fn add_nulls(array: &ArrayRef) -> ArrayRef { + use arrow::datatypes::DataType; + + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + // RunArray can't have top-level nulls, so apply nulls to the values array + let run_array = array + .as_any() + .downcast_ref::>() + .expect("Expected RunArray"); + + let run_ends_buffer = run_array.run_ends().inner().clone(); + let run_ends_array = PrimitiveArray::::new(run_ends_buffer, None); + let values = run_array.values().clone(); + + // Add nulls to the values array + let values_with_nulls = { + let array_data = values + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(values.len()))) + .build() + .unwrap(); + make_array(array_data) + }; + + Arc::new( + RunArray::try_new(&run_ends_array, values_with_nulls.as_ref()) + .expect("Failed to create RunArray with null values"), + ) + } + _ => { + let array_data = array + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(array.len()))) + .build() + .unwrap(); + make_array(array_data) + } + } +} + +pub fn make_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +/// String pool for generating low cardinality data (for dictionaries and string views) +struct StringPool { + strings: Vec, +} + +impl StringPool { + /// Create a new string pool with the given number of random strings + /// each having between 1 and max_length characters. 
fn new(pool_size: usize, max_length: usize) -> Self {
+        let mut rng = make_rng();
+        let mut strings = Vec::with_capacity(pool_size);
+        for _ in 0..pool_size {
+            let len = rng.random_range(1..=max_length);
+            let value: Vec<u8> =
+                rng.clone().sample_iter(&Alphanumeric).take(len).collect();
+            strings.push(String::from_utf8(value).unwrap());
+        }
+        Self { strings }
+    }
+
+    /// Return an iterator over &str of the given length with values randomly chosen from the pool
+    fn iter_strings(&self, len: usize) -> impl Iterator<Item = &str> {
+        let mut rng = make_rng();
+        (0..len).map(move |_| {
+            let idx = rng.random_range(0..self.strings.len());
+            self.strings[idx].as_str()
+        })
+    }
+
+    /// Return a StringArray of the given length with values randomly chosen from the pool
+    fn string_array<O: OffsetSizeTrait>(&self, array_length: usize) -> ArrayRef {
+        Arc::new(GenericStringArray::<O>::from_iter_values(
+            self.iter_strings(array_length),
+        ))
+    }
+
+    /// Return a StringViewArray of the given length with values randomly chosen from the pool
+    fn string_view_array(&self, array_length: usize) -> ArrayRef {
+        Arc::new(StringViewArray::from_iter_values(
+            self.iter_strings(array_length),
+        ))
+    }
+
+    /// Return a DictionaryArray of the given length with values randomly chosen from the pool
+    fn dictionary_array<K: ArrowDictionaryKeyType>(
+        &self,
+        array_length: usize,
+    ) -> ArrayRef {
+        Arc::new(DictionaryArray::<K>::from_iter(
+            self.iter_strings(array_length),
+        ))
+    }
+}
+
+pub fn primitive_array<T>(array_len: usize) -> ArrayRef
+where
+    T: ArrowPrimitiveType,
+    StandardUniform: Distribution<T::Native>,
+{
+    let mut rng = make_rng();
+
+    let array: PrimitiveArray<T> = (0..array_len)
+        .map(|_| Some(rng.random::<T::Native>()))
+        .collect();
+    Arc::new(array)
+}
+
+/// Benchmark sliced arrays to demonstrate the optimization for slices: when an
+/// array is sliced, the underlying buffer may be much larger than what the
+/// slice references. The optimization avoids hashing unreferenced elements.
+fn sliced_array_benchmark(c: &mut Criterion) { + // Test with different slice ratios: slice_size / total_size + // Smaller ratio = more potential savings from the optimization + let slice_ratios = [10, 5, 2]; // 1/10, 1/5, 1/2 of total + + for ratio in slice_ratios { + let total_rows = BATCH_SIZE * ratio; + let slice_offset = BATCH_SIZE * (ratio / 2); // Take from middle + let slice_len = BATCH_SIZE; + + // Sliced ListArray + { + let full_array = list_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("list_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced MapArray + { + let full_array = map_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("map_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced Sparse UnionArray + { + let full_array = sparse_union_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("sparse_union_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + } +} + +fn do_hash_test_with_len(b: &mut Bencher, arrays: &[ArrayRef], expected_len: usize) { + let state = RandomState::default(); + b.iter(|| { + with_hashes(arrays, &state, |hashes| { + assert_eq!(hashes.len(), expected_len); + Ok(()) + }) + .unwrap(); + }); +} + +fn list_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let elements_per_row = 5; + let total_elements = num_rows * elements_per_row; + + let values: Int64Array = (0..total_elements) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(values), + None, + )) +} + +fn map_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let entries_per_row = 5; + let total_entries = num_rows * entries_per_row; + + let keys: Int32Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let values: Int64Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * entries_per_row) as i32) + .collect(); + + let entries = StructArray::try_new( + Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ]), + vec![Arc::new(keys), Arc::new(values)], + None, + ) + .unwrap(); + + Arc::new(MapArray::new( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ])), + false, + )), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + entries, + None, + false, + )) +} + +fn sparse_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + + let type_ids: Vec = (0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + let (fields, children): (Vec<_>, Vec<_>) = 
(0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(num_rows), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + None, + children, + ) + .unwrap(), + ) +} + +fn dense_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + let type_ids: Vec = (0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + + let mut type_counts = vec![0i32; num_types]; + for &tid in &type_ids { + type_counts[tid as usize] += 1; + } + + let mut current_offsets = vec![0i32; num_types]; + let offsets: Vec = type_ids + .iter() + .map(|&tid| { + let offset = current_offsets[tid as usize]; + current_offsets[tid as usize] += 1; + offset + }) + .collect(); + + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(type_counts[i] as usize), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + Some(ScalarBuffer::from(offsets)), + children, + ) + .unwrap(), + ) +} + +fn boolean_array(array_len: usize) -> ArrayRef { + let mut rng = make_rng(); + Arc::new( + (0..array_len) + .map(|_| Some(rng.random::())) + .collect::(), + ) +} + +/// Create a StructArray with multiple columns +fn create_struct_array(pool: &StringPool, array_len: usize) -> ArrayRef { + let bool_array = boolean_array(array_len); + let int32_array = primitive_array::(array_len); + let int64_array = primitive_array::(array_len); + let str_array = pool.string_array::(array_len); + + let fields = Fields::from(vec![ + Field::new("bool_col", DataType::Boolean, false), + Field::new("int32_col", DataType::Int32, false), + Field::new("int64_col", DataType::Int64, false), + Field::new("string_col", DataType::Utf8, false), + ]); + + Arc::new(StructArray::new( + fields, + vec![bool_array, int32_array, int64_array, str_array], + None, + )) +} + +/// Create a RunArray to test run array hashing. +fn create_run_array(array_len: usize) -> ArrayRef +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + + // Create runs of varying lengths + let mut run_ends = Vec::new(); + let mut values = Vec::new(); + let mut current_end = 0; + + while current_end < array_len { + // Random run length between 1 and 50 + let run_length = rng.random_range(1..=50).min(array_len - current_end); + current_end += run_length; + run_ends.push(current_end as i32); + values.push(Some(rng.random::())); + } + + let run_ends_array = Arc::new(PrimitiveArray::::from(run_ends)); + let values_array: Arc = + Arc::new(values.into_iter().collect::>()); + + Arc::new( + RunArray::try_new(&run_ends_array, values_array.as_ref()) + .expect("Failed to create RunArray"), + ) +} + +criterion_group!(benches, criterion_benchmark, sliced_array_benchmark); +criterion_main!(benches); diff --git a/datafusion/common/src/alias.rs b/datafusion/common/src/alias.rs index 2ee2cb4dc7add..99f6447a6acd8 100644 --- a/datafusion/common/src/alias.rs +++ b/datafusion/common/src/alias.rs @@ -37,6 +37,16 @@ impl AliasGenerator { Self::default() } + /// Advance the counter to at least `min_id`, ensuring future aliases + /// won't collide with already-existing ones. 
+ /// + /// For example, if the query already contains an alias `alias_42`, then calling + /// `update_min_id(42)` will ensure that future aliases generated by this + /// [`AliasGenerator`] will start from `alias_43`. + pub fn update_min_id(&self, min_id: usize) { + self.next_id.fetch_max(min_id + 1, Ordering::Relaxed); + } + /// Return a unique alias with the provided prefix pub fn next(&self, prefix: &str) -> String { let id = self.next_id.fetch_add(1, Ordering::Relaxed); diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index e6eda3c585e89..bc4313ed95665 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -20,13 +20,14 @@ //! but provide an error message rather than a panic, as the corresponding //! kernels in arrow-rs such as `as_boolean_array` do. -use crate::{downcast_value, Result}; +use crate::{Result, downcast_value}; use arrow::array::{ BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, - Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray, - UInt16Array, + Int8Array, Int16Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, + ListViewArray, RunArray, StringViewArray, UInt16Array, }; +use arrow::datatypes::RunEndIndexType; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, @@ -37,8 +38,8 @@ use arrow::{ MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt32Array, UInt64Array, - UInt8Array, UnionArray, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt32Array, + UInt64Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; @@ -324,3 +325,18 @@ pub fn as_generic_string_array( ) -> Result<&GenericStringArray> { Ok(downcast_value!(array, GenericStringArray, T)) } + +// Downcast Array to ListViewArray +pub fn as_list_view_array(array: &dyn Array) -> Result<&ListViewArray> { + Ok(downcast_value!(array, ListViewArray)) +} + +// Downcast Array to LargeListViewArray +pub fn as_large_list_view_array(array: &dyn Array) -> Result<&LargeListViewArray> { + Ok(downcast_value!(array, LargeListViewArray)) +} + +// Downcast Array to RunArray +pub fn as_run_array(array: &dyn Array) -> Result<&RunArray> { + Ok(downcast_value!(array, RunArray, T)) +} diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index f4afdf7002078..2889259dd4820 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -22,8 +22,9 @@ use arrow_ipc::CompressionType; #[cfg(feature = "parquet_encryption")] use crate::encryption::{FileDecryptionProperties, FileEncryptionProperties}; use crate::error::_config_err; -use crate::format::{ExplainAnalyzeLevel, ExplainFormat}; -use crate::parsers::CompressionTypeVariant; +use crate::format::{ExplainAnalyzeCategories, ExplainFormat, MetricType}; +use crate::parquet_config::DFParquetWriterVersion; +use crate::parsers::{CompressionTypeVariant, CsvQuoteStyle}; use crate::utils::get_available_parallelism; use crate::{DataFusionError, Result}; #[cfg(feature = "parquet_encryption")] @@ -157,12 +158,10 @@ macro_rules! config_namespace { // $(#[allow(deprecated)])? { $(let value = $transform(value);)? 
// Apply transformation if specified - #[allow(deprecated)] let ret = self.$field_name.set(rem, value.as_ref()); $(if !$warn.is_empty() { let default: $field_type = $default; - #[allow(deprecated)] if default != self.$field_name { log::warn!($warn); } @@ -181,14 +180,36 @@ macro_rules! config_namespace { $( let key = format!(concat!("{}.", stringify!($field_name)), key_prefix); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } + + fn reset(&mut self, key: &str) -> $crate::error::Result<()> { + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + $( + stringify!($field_name) => { + { + if rem.is_empty() { + let default_value: $field_type = $default; + self.$field_name = default_value; + Ok(()) + } else { + self.$field_name.reset(rem) + } + } + }, + )* + _ => $crate::error::_config_err!( + "Config value \"{}\" not found on {}", + key, + stringify!($struct_name) + ), + } + } } impl Default for $struct_name { fn default() -> Self { - #[allow(deprecated)] Self { $($field_name: $default),* } @@ -290,6 +311,15 @@ config_namespace! { /// By default, `nulls_max` is used to follow Postgres's behavior. /// postgres rule: pub default_null_ordering: String, default = "nulls_max".to_string() + + /// When set to true, DataFusion may remove `ORDER BY` clauses from + /// subqueries or CTEs during SQL planning when their ordering cannot + /// affect the result, such as when no `LIMIT` or other + /// order-sensitive operator depends on them. + /// + /// Disable this option to preserve explicit subquery ordering in the + /// planned query. + pub enable_subquery_sort_elimination: bool, default = true } } @@ -448,6 +478,25 @@ config_namespace! { /// metadata memory consumption pub batch_size: usize, default = 8192 + /// A perfect hash join (see `HashJoinExec` for more details) will be considered + /// if the range of keys (max - min) on the build side is < this threshold. + /// This provides a fast path for joins with very small key ranges, + /// bypassing the density check. + /// + /// Currently only supports cases where build_side.num_rows() < u32::MAX. + /// Support for build_side.num_rows() >= u32::MAX will be added in the future. + pub perfect_hash_join_small_build_threshold: usize, default = 1024 + + /// The minimum required density of join keys on the build side to consider a + /// perfect hash join (see `HashJoinExec` for more details). Density is calculated as: + /// `(number of rows) / (max_key - min_key + 1)`. + /// A perfect hash join may be used if the actual key density > this + /// value. + /// + /// Currently only supports cases where build_side.num_rows() < u32::MAX. + /// Support for build_side.num_rows() >= u32::MAX will be added in the future. + pub perfect_hash_join_min_key_density: f64, default = 0.15 + /// When set to true, record batches will be examined between each operator and /// small batches will be coalesced into larger batches. This is helpful when there /// are highly selective filters or joins that could produce tiny output batches. The @@ -517,6 +566,36 @@ config_namespace! { /// batches and merged. pub sort_in_place_threshold_bytes: usize, default = 1024 * 1024 + /// Maximum buffer capacity (in bytes) per partition for BufferExec + /// inserted during sort pushdown optimization. + /// + /// When PushdownSort eliminates a SortExec under SortPreservingMergeExec, + /// a BufferExec is inserted to replace SortExec's buffering role. 
     /// When set to true, record batches will be examined between each operator and
     /// small batches will be coalesced into larger batches. This is helpful when there
     /// are highly selective filters or joins that could produce tiny output batches. The
@@ -517,6 +566,36 @@ config_namespace! {
     /// batches and merged.
     pub sort_in_place_threshold_bytes: usize, default = 1024 * 1024

+    /// Maximum buffer capacity (in bytes) per partition for BufferExec
+    /// inserted during sort pushdown optimization.
+    ///
+    /// When PushdownSort eliminates a SortExec under SortPreservingMergeExec,
+    /// a BufferExec is inserted to replace SortExec's buffering role. This
+    /// prevents I/O stalls by allowing the scan to run ahead of the merge.
+    ///
+    /// This uses strictly less memory than the SortExec it replaces (which
+    /// buffers the entire partition). The buffer respects the global memory
+    /// pool limit. Setting this to a large value is safe; actual memory
+    /// usage is bounded by partition size and global memory limits.
+    pub sort_pushdown_buffer_capacity: usize, default = 1024 * 1024 * 1024
+
+    /// Maximum size in bytes for individual spill files before rotating to a new file.
+    ///
+    /// When operators spill data to disk (e.g., RepartitionExec), they write
+    /// multiple batches to the same file until this size limit is reached, then rotate
+    /// to a new file. This reduces syscall overhead compared to one-file-per-batch
+    /// while preventing files from growing too large.
+    ///
+    /// A larger value reduces file creation overhead but may hold on to disk space longer.
+    /// A smaller value creates more files but allows finer-grained space reclamation,
+    /// as files can be deleted once fully consumed.
+    ///
+    /// Currently only `RepartitionExec` supports this spill file rotation feature; other
+    /// spilling operators may create spill files larger than the limit.
+    ///
+    /// Default: 128 MB
+    pub max_spill_file_size_bytes: usize, default = 128 * 1024 * 1024
+
     /// Number of files to read in parallel when inferring schema and statistics
     pub meta_fetch_concurrency: usize, default = 32

@@ -589,6 +668,175 @@ config_namespace! {
     /// written, it may be necessary to increase this size to avoid errors from
     /// the remote end point.
     pub objectstore_writer_buffer_size: usize, default = 10 * 1024 * 1024
+
+    /// Whether to enable ANSI SQL mode.
+    ///
+    /// The flag is experimental and relevant only for DataFusion Spark built-in functions.
+    ///
+    /// When `enable_ansi_mode` is set to `true`, the query engine follows ANSI SQL
+    /// semantics for expressions, casting, and error handling. This means:
+    /// - **Strict type coercion rules:** implicit casts between incompatible types are disallowed.
+    /// - **Standard SQL arithmetic behavior:** operations such as division by zero,
+    ///   numeric overflow, or invalid casts raise runtime errors rather than returning
+    ///   `NULL` or adjusted values.
+    /// - **Consistent ANSI behavior** for string concatenation, comparisons, and `NULL` handling.
+    ///
+    /// When `enable_ansi_mode` is `false` (the default), the engine uses a more permissive,
+    /// non-ANSI mode designed for user convenience and backward compatibility. In this mode:
+    /// - Implicit casts between types are allowed (e.g., string to integer when possible).
+    /// - Arithmetic operations are more lenient; for example, `abs()` on the minimum
+    ///   representable integer value returns the input value instead of raising an
+    ///   overflow error.
+    /// - Division by zero or invalid casts may return `NULL` instead of failing.
+    ///
+    /// # Default
+    /// `false`: ANSI SQL mode is disabled by default.
+    pub enable_ansi_mode: bool, default = false
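Both new execution options are reachable through the existing string-based config API. A minimal sketch, assuming the `datafusion_common` crate; the key names come from the diff above and the values are illustrative:

```rust
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;

fn main() -> Result<()> {
    let mut config = ConfigOptions::new();

    // Rotate spill files every 256 MB instead of the 128 MB default.
    config.set(
        "datafusion.execution.max_spill_file_size_bytes",
        "268435456",
    )?;

    // Opt in to the experimental ANSI mode for Spark-compatible functions.
    config.set("datafusion.execution.enable_ansi_mode", "true")?;

    assert_eq!(config.execution.max_spill_file_size_bytes, 256 * 1024 * 1024);
    assert!(config.execution.enable_ansi_mode);
    Ok(())
}
```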
+    /// How many bytes to buffer in the probe side of hash joins while the build side is
+    /// concurrently being built.
+    ///
+    /// Without this, hash joins will wait until the full materialization of the build side
+    /// before polling the probe side. This is useful in scenarios where the query is not
+    /// completely CPU bounded, allowing some early work to happen concurrently and reducing the
+    /// latency of the query.
+    ///
+    /// Note that when hash join buffering is enabled, the probe side will start eagerly
+    /// polling data, not giving time for the producer side of dynamic filters to produce any
+    /// meaningful predicate. Queries with dynamic filters might see performance degradation.
+    ///
+    /// Disabled by default; set to a number greater than 0 to enable it.
+    pub hash_join_buffering_capacity: usize, default = 0
+    }
 }

+/// Options for content-defined chunking (CDC) when writing parquet files.
+/// See [`ParquetOptions::use_content_defined_chunking`].
+///
+/// Can be enabled with default options by setting
+/// `use_content_defined_chunking` to `true`, or configured with sub-fields
+/// like `use_content_defined_chunking.min_chunk_size`.
+#[derive(Debug, Clone, PartialEq)]
+pub struct CdcOptions {
+    /// Minimum chunk size in bytes. The rolling hash will not trigger a split
+    /// until this many bytes have been accumulated. Default is 256 KiB.
+    pub min_chunk_size: usize,
+
+    /// Maximum chunk size in bytes. A split is forced when the accumulated
+    /// size exceeds this value. Default is 1 MiB.
+    pub max_chunk_size: usize,
+
+    /// Normalization level. Increasing this improves deduplication ratio
+    /// but increases fragmentation. Recommended range is [-3, 3], default is 0.
+    pub norm_level: i32,
+}
+
+// Note: `CdcOptions` intentionally does NOT implement `Default` so that the
+// blanket `impl<F: ConfigField + Default> ConfigField for Option<F>` does not
+// apply. This allows the specific `impl ConfigField for Option<CdcOptions>`
+// below to handle "true"/"false" for enabling/disabling CDC.
+// Use `CdcOptions::default()` (the inherent method) instead of `Default::default()`.
+impl CdcOptions {
+    /// Returns a new `CdcOptions` with default values.
+    #[expect(clippy::should_implement_trait)]
+    pub fn default() -> Self {
+        Self {
+            min_chunk_size: 256 * 1024,
+            max_chunk_size: 1024 * 1024,
+            norm_level: 0,
+        }
+    }
+}
+
+impl ConfigField for CdcOptions {
+    fn set(&mut self, key: &str, value: &str) -> Result<()> {
+        let (key, rem) = key.split_once('.').unwrap_or((key, ""));
+        match key {
+            "min_chunk_size" => self.min_chunk_size.set(rem, value),
+            "max_chunk_size" => self.max_chunk_size.set(rem, value),
+            "norm_level" => self.norm_level.set(rem, value),
+            _ => _config_err!("Config value \"{}\" not found on CdcOptions", key),
+        }
+    }
+
+    fn visit<V: Visit>(&self, v: &mut V, key_prefix: &str, _description: &'static str) {
+        let key = format!("{key_prefix}.min_chunk_size");
+        self.min_chunk_size.visit(v, &key, "Minimum chunk size in bytes. The rolling hash will not trigger a split until this many bytes have been accumulated. Default is 256 KiB.");
+        let key = format!("{key_prefix}.max_chunk_size");
+        self.max_chunk_size.visit(v, &key, "Maximum chunk size in bytes. A split is forced when the accumulated size exceeds this value. Default is 1 MiB.");
+        let key = format!("{key_prefix}.norm_level");
+        self.norm_level.visit(v, &key, "Normalization level. Increasing this improves deduplication ratio but increases fragmentation. Recommended range is [-3, 3], default is 0.");
+    }
+
+    fn reset(&mut self, key: &str) -> Result<()> {
+        let (key, rem) = key.split_once('.').unwrap_or((key, ""));
+        match key {
+            "min_chunk_size" => {
+                if rem.is_empty() {
+                    self.min_chunk_size = CdcOptions::default().min_chunk_size;
+                    Ok(())
+                } else {
+                    self.min_chunk_size.reset(rem)
+                }
+            }
+            "max_chunk_size" => {
+                if rem.is_empty() {
+                    self.max_chunk_size = CdcOptions::default().max_chunk_size;
+                    Ok(())
+                } else {
+                    self.max_chunk_size.reset(rem)
+                }
+            }
+            "norm_level" => {
+                if rem.is_empty() {
+                    self.norm_level = CdcOptions::default().norm_level;
+                    Ok(())
+                } else {
+                    self.norm_level.reset(rem)
+                }
+            }
+            _ => _config_err!("Config value \"{}\" not found on CdcOptions", key),
+        }
+    }
+}
+
+/// `ConfigField` for `Option<CdcOptions>`: allows setting the option to
+/// `"true"` (enable with defaults) or `"false"` (disable), in addition to
+/// setting individual sub-fields like `min_chunk_size`.
+impl ConfigField for Option<CdcOptions> {
+    fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
+        match self {
+            Some(s) => s.visit(v, key, description),
+            None => v.none(key, description),
+        }
+    }
+
+    fn set(&mut self, key: &str, value: &str) -> Result<()> {
+        if key.is_empty() {
+            match value.to_ascii_lowercase().as_str() {
+                "true" => {
+                    *self = Some(CdcOptions::default());
+                    Ok(())
+                }
+                "false" => {
+                    *self = None;
+                    Ok(())
+                }
+                _ => _config_err!(
+                    "Expected 'true' or 'false' for use_content_defined_chunking, got '{value}'"
+                ),
+            }
+        } else {
+            self.get_or_insert_with(CdcOptions::default).set(key, value)
+        }
+    }
+
+    fn reset(&mut self, key: &str) -> Result<()> {
+        if key.is_empty() {
+            *self = None;
+            Ok(())
+        } else {
+            self.get_or_insert_with(CdcOptions::default).reset(key)
+        }
+    }
+}
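Putting the pieces together: the field accepts `true`/`false` as well as dotted sub-keys, and the new `reset` path returns it to `None`. A sketch under those assumptions (the `ConfigField` trait must be in scope for `reset`; requires the "parquet" feature):

```rust
use datafusion_common::config::{ConfigField, ConfigOptions};
use datafusion_common::Result;

fn main() -> Result<()> {
    let mut config = ConfigOptions::new();
    let key = "datafusion.execution.parquet.use_content_defined_chunking";

    // "true" enables CDC with CdcOptions::default()
    config.set(key, "true")?;

    // Assigning a sub-field also enables CDC implicitly
    config.set(&format!("{key}.norm_level"), "2")?;
    let cdc = config
        .execution
        .parquet
        .use_content_defined_chunking
        .as_ref()
        .unwrap();
    assert_eq!(cdc.norm_level, 2);

    // `reset` with the bare key returns the option to its default (None)
    config.reset(key)?;
    assert!(config.execution.parquet.use_content_defined_chunking.is_none());
    Ok(())
}
```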
@@ -634,6 +882,12 @@ config_namespace! {
     /// the filters are applied in the same order as written in the query
     pub reorder_filters: bool, default = false

+    /// (reading) Force the use of RowSelections for filter results, when
+    /// pushdown_filters is enabled. If false, the reader will automatically
+    /// choose between a RowSelection and a Bitmap based on the number and
+    /// pattern of selected rows.
+    pub force_filter_selections: bool, default = false
+
     /// (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`,
     /// and `Binary/BinaryLarge` with `BinaryView`.
     pub schema_force_view_types: bool, default = true
@@ -671,12 +925,12 @@ config_namespace! {
     /// (writing) Sets best effort maximum size of data page in bytes
     pub data_pagesize_limit: usize, default = 1024 * 1024

-    /// (writing) Sets write_batch_size in bytes
+    /// (writing) Sets write_batch_size in rows
     pub write_batch_size: usize, default = 1024

     /// (writing) Sets parquet writer version
     /// valid values are "1.0" and "2.0"
-    pub writer_version: String, default = "1.0".to_string()
+    pub writer_version: DFParquetWriterVersion, default = DFParquetWriterVersion::default()

     /// (writing) Skip encoding the embedded arrow metadata in the KV_meta
     ///
@@ -686,7 +940,7 @@ config_namespace! {

     /// (writing) Sets default parquet compression codec.
     /// Valid values are: uncompressed, snappy, gzip(level),
-    /// lzo, brotli(level), lz4, zstd(level), and lz4_raw.
+    /// brotli(level), lz4, zstd(level), and lz4_raw.
     /// These values are not case sensitive. If NULL, uses
     /// default parquet writer setting
     ///
@@ -771,6 +1025,12 @@ config_namespace! {
     /// writing out already in-memory data, such as from a cached
     /// data frame.
     pub maximum_buffered_record_batches_per_stream: usize, default = 2
+
+    /// (writing) EXPERIMENTAL: Enable content-defined chunking (CDC) when writing
+    /// parquet files. When `Some`, CDC is enabled with the given options; when `None`
+    /// (the default), CDC is disabled. When CDC is enabled, parallel writing is
+    /// automatically disabled since the chunker state must persist across row groups.
+    pub use_content_defined_chunking: Option<CdcOptions>, default = None
     }
 }

@@ -836,6 +1096,20 @@ config_namespace! {
     /// past window functions, if possible
     pub enable_window_limits: bool, default = true

+    /// When set to true, the optimizer will replace
+    /// Filter(rn<=K) → Window(ROW_NUMBER) → Sort patterns with a
+    /// PartitionedTopKExec that maintains per-partition heaps, avoiding
+    /// a full sort of the input.
+    /// When the window partition key has low cardinality, enabling this optimization
+    /// can improve performance. However, for high cardinality keys, it may
+    /// cause regressions in both memory usage and runtime.
+    pub enable_window_topn: bool, default = false
+
+    /// When set to true, the optimizer will push TopK (Sort with fetch)
+    /// below hash repartition when the partition key is a prefix of the
+    /// sort key, reducing data volume before the shuffle.
+    pub enable_topk_repartition: bool, default = true
+
     /// When set to true, the optimizer will attempt to push down TopK dynamic filters
     /// into the file scan phase.
     pub enable_topk_dynamic_filter_pushdown: bool, default = true
@@ -844,12 +1118,16 @@ config_namespace! {
     /// into the file scan phase.
     pub enable_join_dynamic_filter_pushdown: bool, default = true

-    /// When set to true attempts to push down dynamic filters generated by operators (topk & join) into the file scan phase.
+    /// When set to true, the optimizer will attempt to push down Aggregate dynamic filters
+    /// into the file scan phase.
+    pub enable_aggregate_dynamic_filter_pushdown: bool, default = true
+
+    /// When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase.
     /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer
     /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans.
     /// This means that if we already have 10 timestamps in the year 2025
     /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan.
-    /// The config will suppress `enable_join_dynamic_filter_pushdown` & `enable_topk_dynamic_filter_pushdown`
+    /// The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown`
     /// So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden.
     pub enable_dynamic_filter_pushdown: bool, default = true

@@ -895,6 +1173,19 @@ config_namespace! {
     /// record tables provided to the MemTable on creation.
     pub repartition_file_scans: bool, default = true

+    /// Minimum number of distinct partition values required to group files by their
+    /// Hive partition column values (enabling Hash partitioning declaration).
+    ///
+    /// How the option is used:
+    /// - preserve_file_partitions=0: Disable it.
+    /// - preserve_file_partitions=1: Always enable it.
+    /// - preserve_file_partitions=N, actual file partitions=M: Only enable when M >= N.
+    /// This threshold preserves I/O parallelism when file partitioning is below it.
+    ///
+    /// Note: This may reduce parallelism, stemming from the I/O level, if the number of distinct
+    /// partitions is less than the target_partitions.
+    pub preserve_file_partitions: usize, default = 0
+
     /// Should DataFusion repartition data using the partitions keys to execute window
     /// functions in parallel using the provided `target_partitions` level
     pub repartition_windows: bool, default = true
@@ -917,6 +1208,34 @@ config_namespace! {
     /// ```
     pub repartition_sorts: bool, default = true

+    /// Partition count threshold for subset satisfaction optimization.
+    ///
+    /// When the current partition count is >= this threshold, DataFusion will
+    /// skip repartitioning if the required partitioning expression is a subset
+    /// of the current partition expression such as Hash(a) satisfies Hash(a, b).
+    ///
+    /// When the current partition count is < this threshold, DataFusion will
+    /// repartition to increase parallelism even when subset satisfaction applies.
+    ///
+    /// Set to 0 to always repartition (disable subset satisfaction optimization).
+    /// Set to a high value to always use subset satisfaction.
+    ///
+    /// Example (subset_repartition_threshold = 4):
+    /// ```text
+    /// Hash([a]) satisfies Hash([a, b]) because [a] is a subset of [a, b]
+    ///
+    /// If current partitions (3) < threshold (4), repartition:
+    ///   AggregateExec: mode=FinalPartitioned, gby=[a, b], aggr=[SUM(x)]
+    ///     RepartitionExec: partitioning=Hash([a, b], 8), input_partitions=3
+    ///       AggregateExec: mode=Partial, gby=[a, b], aggr=[SUM(x)]
+    ///         DataSourceExec: file_groups={...}, output_partitioning=Hash([a], 3)
+    ///
+    /// If current partitions (8) >= threshold (4), use subset satisfaction:
+    ///   AggregateExec: mode=SinglePartitioned, gby=[a, b], aggr=[SUM(x)]
+    ///     DataSourceExec: file_groups={...}, output_partitioning=Hash([a], 8)
+    /// ```
+    pub subset_repartition_threshold: usize, default = 4
+
     /// When true, DataFusion will opportunistically remove sorts when the data is already sorted,
     /// (i.e. setting `preserve_order` to true on `RepartitionExec` and
     /// using `SortPreservingMergeExec`)
@@ -937,6 +1256,18 @@ config_namespace! {
     /// process to reorder the join keys
     pub top_down_join_key_reordering: bool, default = true

+    /// When set to true, the physical plan optimizer may swap join inputs
+    /// based on statistics. When set to false, statistics-driven join
+    /// input reordering is disabled and the original join order in the
+    /// query is used.
+    pub join_reordering: bool, default = true
+
+    /// When set to true, the physical plan optimizer uses the pluggable
+    /// `StatisticsRegistry` for statistics propagation across operators.
+    /// This enables more accurate cardinality estimates compared to each
+    /// operator's built-in `partition_statistics`.
+    pub use_statistics_registry: bool, default = false
+
     /// When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin.
     /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory
     pub prefer_hash_join: bool, default = true
@@ -954,6 +1285,36 @@ config_namespace! {
     /// will be collected into a single partition
     pub hash_join_single_partition_threshold_rows: usize, default = 1024 * 128

+    /// Maximum size in bytes for the build side of a hash join to be pushed down as an InList expression for dynamic filtering.
+    /// Build sides larger than this will use hash table lookups instead.
+    /// Set to 0 to always use hash table lookups.
+    ///
+    /// InList pushdown can be more efficient for small build sides because it can result in better
+    /// statistics pruning as well as use any bloom filters present on the scan side.
+    /// InList expressions are also more transparent and easier to serialize over the network in distributed uses of DataFusion.
+    /// On the other hand InList pushdown requires making a copy of the data and thus adds some overhead to the build side and uses more memory.
+    ///
+    /// This setting is per-partition, so we may end up using `hash_join_inlist_pushdown_max_size` * `target_partitions` memory.
+    ///
+    /// The default is 128kB per partition.
+    /// This should allow point lookup joins (e.g. joining on a unique primary key) to use InList pushdown in most cases
+    /// but avoids excessive memory usage or overhead for larger joins.
+    pub hash_join_inlist_pushdown_max_size: usize, default = 128 * 1024
+
+    /// Maximum number of distinct values (rows) in the build side of a hash join to be pushed down as an InList expression for dynamic filtering.
+    /// Build sides with more rows than this will use hash table lookups instead.
+    /// Set to 0 to always use hash table lookups.
+    ///
+    /// This provides an additional limit beyond `hash_join_inlist_pushdown_max_size` to prevent
+    /// very large IN lists that might not provide much benefit over hash table lookups.
+    ///
+    /// This uses the deduplicated row count once the build side has been evaluated.
+    ///
+    /// The default is 150 values per partition.
+    /// This is inspired by Trino's `max-filter-keys-per-column` setting.
+    /// See:
+    pub hash_join_inlist_pushdown_max_distinct_values: usize, default = 150
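The two limits act as a combined gate. The following sketch restates that decision logic; it is not DataFusion's internal code, only an illustration of the documented behavior with the default values:

```rust
// Sketch of the two-threshold InList pushdown gate described above.
struct InListPushdownLimits {
    max_size_bytes: usize,      // hash_join_inlist_pushdown_max_size (default 128 KiB)
    max_distinct_values: usize, // hash_join_inlist_pushdown_max_distinct_values (default 150)
}

impl InListPushdownLimits {
    /// Returns true if the (deduplicated) build side may be pushed down
    /// as an InList expression; otherwise hash-table lookups are used.
    fn use_inlist(&self, build_side_bytes: usize, distinct_values: usize) -> bool {
        // A threshold of 0 disables InList pushdown entirely.
        self.max_size_bytes > 0
            && self.max_distinct_values > 0
            && build_side_bytes <= self.max_size_bytes
            && distinct_values <= self.max_distinct_values
    }
}

fn main() {
    let limits = InListPushdownLimits {
        max_size_bytes: 128 * 1024,
        max_distinct_values: 150,
    };
    assert!(limits.use_inlist(4 * 1024, 100)); // small point-lookup join: InList
    assert!(!limits.use_inlist(4 * 1024, 5_000)); // too many keys: hash lookups
}
```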
+
     /// The default filter selectivity used by Filter Statistics
     /// when an exact selectivity cannot be determined. Valid values are
     /// between 0 (no selectivity) and 100 (all rows are selected).
@@ -966,6 +1327,27 @@ config_namespace! {
     /// then the output will be coerced to a non-view.
     /// Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`.
     pub expand_views_at_output: bool, default = false
+
+    /// Enable sort pushdown optimization.
+    /// When enabled, attempts to push sort requirements down to data sources
+    /// that can natively handle them (e.g., by reversing file/row group read order).
+    ///
+    /// Returns **inexact ordering**: Sort operator is kept for correctness,
+    /// but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N),
+    /// providing significant speedup.
+    ///
+    /// Memory: No additional overhead (only changes read order).
+    ///
+    /// Future: Will add option to detect perfectly sorted data and eliminate Sort completely.
+    ///
+    /// Default: true
+    pub enable_sort_pushdown: bool, default = true
+
+    /// When set to true, the optimizer will extract leaf expressions
+    /// (such as `get_field`) from filter/sort/join nodes into projections
+    /// closer to the leaf table scans, and push those projections down
+    /// towards the leaf nodes.
+    pub enable_leaf_expression_pushdown: bool, default = true
     }
 }

@@ -1003,7 +1385,13 @@ config_namespace! {
     /// Verbosity level for "EXPLAIN ANALYZE". Default is "dev"
     /// "summary" shows common metrics for high-level insights.
     /// "dev" provides deep operator-level introspection for developers.
-    pub analyze_level: ExplainAnalyzeLevel, default = ExplainAnalyzeLevel::Dev
+    pub analyze_level: MetricType, default = MetricType::Dev
+
+    /// Which metric categories to include in "EXPLAIN ANALYZE" output.
+    /// Comma-separated list of: "rows", "bytes", "timing", "uncategorized".
+    /// Use "none" to show plan structure only, or "all" (default) to show everything.
+    /// Metrics without a declared category are treated as "uncategorized".
+    pub analyze_categories: ExplainAnalyzeCategories, default = ExplainAnalyzeCategories::All
     }
 }

@@ -1046,35 +1434,35 @@ config_namespace! {
     }
 }

-impl<'a> TryInto<arrow::util::display::FormatOptions<'a>> for &'a FormatOptions {
+impl<'a> TryFrom<&'a FormatOptions> for arrow::util::display::FormatOptions<'a> {
     type Error = DataFusionError;

-    fn try_into(self) -> Result<arrow::util::display::FormatOptions<'a>> {
-        let duration_format = match self.duration_format.as_str() {
+    fn try_from(options: &'a FormatOptions) -> Result<Self> {
+        let duration_format = match options.duration_format.as_str() {
             "pretty" => arrow::util::display::DurationFormat::Pretty,
             "iso8601" => arrow::util::display::DurationFormat::ISO8601,
             _ => {
                 return _config_err!(
                     "Invalid duration format: {}. Valid values are pretty or iso8601",
-                    self.duration_format
-                )
+                    options.duration_format
+                );
             }
         };

-        Ok(arrow::util::display::FormatOptions::new()
-            .with_display_error(self.safe)
-            .with_null(&self.null)
-            .with_date_format(self.date_format.as_deref())
-            .with_datetime_format(self.datetime_format.as_deref())
-            .with_timestamp_format(self.timestamp_format.as_deref())
-            .with_timestamp_tz_format(self.timestamp_tz_format.as_deref())
-            .with_time_format(self.time_format.as_deref())
+        Ok(Self::new()
+            .with_display_error(options.safe)
+            .with_null(&options.null)
+            .with_date_format(options.date_format.as_deref())
+            .with_datetime_format(options.datetime_format.as_deref())
+            .with_timestamp_format(options.timestamp_format.as_deref())
+            .with_timestamp_tz_format(options.timestamp_tz_format.as_deref())
+            .with_time_format(options.time_format.as_deref())
             .with_duration_format(duration_format)
-            .with_types_info(self.types_info))
+            .with_types_info(options.types_info))
     }
 }

 /// A key value pair, with a corresponding description
-#[derive(Debug, Hash, PartialEq, Eq)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub struct ConfigEntry {
     /// A unique string to identify this config value
     pub key: String,
@@ -1107,6 +1495,15 @@ pub struct ConfigOptions {
 }

 impl ConfigField for ConfigOptions {
+    fn visit<V: Visit>(&self, v: &mut V, _key_prefix: &str, _description: &'static str) {
+        self.catalog.visit(v, "datafusion.catalog", "");
+        self.execution.visit(v, "datafusion.execution", "");
+        self.optimizer.visit(v, "datafusion.optimizer", "");
+        self.explain.visit(v, "datafusion.explain", "");
+        self.sql_parser.visit(v, "datafusion.sql_parser", "");
+        self.format.visit(v, "datafusion.format", "");
+    }
+
     fn set(&mut self, key: &str, value: &str) -> Result<()> {
         // Extensions are handled in the public `ConfigOptions::set`
         let (key, rem) = key.split_once('.').unwrap_or((key, ""));
@@ -1121,16 +1518,50 @@ impl ConfigField for ConfigOptions {
         }
     }

-    fn visit<V: Visit>(&self, v: &mut V, _key_prefix: &str, _description: &'static str) {
-        self.catalog.visit(v, "datafusion.catalog", "");
-        self.execution.visit(v, "datafusion.execution", "");
-        self.optimizer.visit(v, "datafusion.optimizer", "");
-        self.explain.visit(v, "datafusion.explain", "");
-        self.sql_parser.visit(v, "datafusion.sql_parser", "");
-        self.format.visit(v, "datafusion.format", "");
+    /// Reset a configuration option back to its default value
+    fn reset(&mut self, key: &str) -> Result<()> {
+        let Some((prefix, rest)) = key.split_once('.') else {
+            return _config_err!("could not find config namespace for key \"{key}\"");
+        };
+
+        if prefix != "datafusion" {
+            return _config_err!("Could not find config namespace \"{prefix}\"");
+        }
+
+        let (section, rem) = rest.split_once('.').unwrap_or((rest, ""));
+        if rem.is_empty() {
+            return _config_err!("could not find config field for key \"{key}\"");
+        }
+
+        match section {
+            "catalog" => self.catalog.reset(rem),
+            "execution" => self.execution.reset(rem),
+            "optimizer" => {
+                if rem == "enable_dynamic_filter_pushdown" {
+                    let defaults = OptimizerOptions::default();
+                    self.optimizer.enable_dynamic_filter_pushdown =
+                        defaults.enable_dynamic_filter_pushdown;
+                    self.optimizer.enable_topk_dynamic_filter_pushdown =
+                        defaults.enable_topk_dynamic_filter_pushdown;
+                    self.optimizer.enable_join_dynamic_filter_pushdown =
+                        defaults.enable_join_dynamic_filter_pushdown;
+                    Ok(())
+                } else {
+                    self.optimizer.reset(rem)
+                }
+            }
+            "explain" => self.explain.reset(rem),
+            "sql_parser" => self.sql_parser.reset(rem),
+            "format" => self.format.reset(rem),
+            other => _config_err!("Config value \"{other}\" not found on ConfigOptions"),
+        }
+    }
 }
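Note that `enable_dynamic_filter_pushdown` is special-cased in both `set` and `reset`, so the umbrella key fans out to its linked options. A sketch of that behavior (illustrative; `ConfigField` must be in scope for `reset`):

```rust
use datafusion_common::config::{ConfigField, ConfigOptions};
use datafusion_common::Result;

fn main() -> Result<()> {
    let mut config = ConfigOptions::new();

    // Setting the umbrella key also flips the per-operator options...
    config.set("datafusion.optimizer.enable_dynamic_filter_pushdown", "false")?;
    assert!(!config.optimizer.enable_topk_dynamic_filter_pushdown);

    // ...and resetting it restores them to their defaults (true).
    config.reset("datafusion.optimizer.enable_dynamic_filter_pushdown")?;
    assert!(config.optimizer.enable_topk_dynamic_filter_pushdown);
    assert!(config.optimizer.enable_join_dynamic_filter_pushdown);
    Ok(())
}
```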
+/// This namespace is reserved for interacting with Foreign Function Interface
+/// (FFI) based configuration extensions.
+pub const DATAFUSION_FFI_CONFIG_NAMESPACE: &str = "datafusion_ffi";
+
 impl ConfigOptions {
     /// Creates a new [`ConfigOptions`] with default values
     pub fn new() -> Self {
@@ -1145,12 +1576,12 @@ impl ConfigOptions {

     /// Set a configuration option
     pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
-        let Some((prefix, key)) = key.split_once('.') else {
+        let Some((mut prefix, mut inner_key)) = key.split_once('.') else {
             return _config_err!("could not find config namespace for key \"{key}\"");
         };

         if prefix == "datafusion" {
-            if key == "optimizer.enable_dynamic_filter_pushdown" {
+            if inner_key == "optimizer.enable_dynamic_filter_pushdown" {
                 let bool_value = value.parse::<bool>().map_err(|e| {
                     DataFusionError::Configuration(format!(
                         "Failed to parse '{value}' as bool: {e}",
@@ -1161,16 +1592,27 @@ impl ConfigOptions {
                 self.optimizer.enable_dynamic_filter_pushdown = bool_value;
                 self.optimizer.enable_topk_dynamic_filter_pushdown = bool_value;
                 self.optimizer.enable_join_dynamic_filter_pushdown = bool_value;
+                self.optimizer.enable_aggregate_dynamic_filter_pushdown = bool_value;
                 }
                 return Ok(());
             }
-        return ConfigField::set(self, key, value);
+            return ConfigField::set(self, inner_key, value);
+        }
+
+        if !self.extensions.0.contains_key(prefix)
+            && self
+                .extensions
+                .0
+                .contains_key(DATAFUSION_FFI_CONFIG_NAMESPACE)
+        {
+            inner_key = key;
+            prefix = DATAFUSION_FFI_CONFIG_NAMESPACE;
         }

         let Some(e) = self.extensions.0.get_mut(prefix) else {
             return _config_err!("Could not find config namespace \"{prefix}\"");
         };
-        e.0.set(key, value)
+        e.0.set(inner_key, value)
     }

     /// Create new [`ConfigOptions`], taking values from environment variables
@@ -1420,6 +1862,14 @@ impl Extensions {
         let e = self.0.get_mut(T::PREFIX)?;
         e.0.as_any_mut().downcast_mut()
     }
+
+    /// Iterates all the config extension entries yielding their prefix and their
+    /// [ExtensionOptions] implementation.
+    pub fn iter(
+        &self,
+    ) -> impl Iterator<Item = (&'static str, &dyn ExtensionOptions)> {
+        self.0.iter().map(|(k, v)| (*k, &v.0))
+    }
 }

 #[derive(Debug)]
@@ -1437,6 +1887,10 @@ pub trait ConfigField {
     fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str);

     fn set(&mut self, key: &str, value: &str) -> Result<()>;
+
+    fn reset(&mut self, key: &str) -> Result<()> {
+        _config_err!("Reset is not supported for this config field, key: {}", key)
+    }
 }

 impl<F: ConfigField + Default> ConfigField for Option<F> {
@@ -1450,6 +1904,15 @@ impl<F: ConfigField + Default> ConfigField for Option<F> {
     fn set(&mut self, key: &str, value: &str) -> Result<()> {
         self.get_or_insert_with(Default::default).set(key, value)
     }
+
+    fn reset(&mut self, key: &str) -> Result<()> {
+        if key.is_empty() {
+            *self = Default::default();
+            Ok(())
+        } else {
+            self.get_or_insert_with(Default::default).reset(key)
+        }
+    }
 }

 /// Default transformation to parse a [`ConfigField`] for a string.
@@ -1514,6 +1977,19 @@ macro_rules! config_field {
                 *self = $transform;
                 Ok(())
             }
+
+            fn reset(&mut self, key: &str) -> $crate::error::Result<()> {
+                if key.is_empty() {
+                    *self = <$t as Default>::default();
+                    Ok(())
+                } else {
+                    $crate::error::_config_err!(
+                        "Config field is a scalar {} and does not have nested field \"{}\"",
+                        stringify!($t),
+                        key
+                    )
+                }
+            }
         }
     };
 }
@@ -1523,6 +1999,8 @@ config_field!(bool, value => default_config_transform(value.to_lowercase().as_str())?);
 config_field!(usize);
 config_field!(f64);
 config_field!(u64);
+config_field!(u32);
+config_field!(i32);

 impl ConfigField for u8 {
     fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
@@ -1564,6 +2042,17 @@
     }
 }

+impl ConfigField for CsvQuoteStyle {
+    fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
+        v.some(key, self, description)
+    }
+
+    fn set(&mut self, _: &str, value: &str) -> Result<()> {
+        *self = CsvQuoteStyle::from_str(value)?;
+        Ok(())
+    }
+}
+
 /// An implementation trait used to recursively walk configuration
 pub trait Visit {
     fn some<V: Display>(&mut self, key: &str, value: V, description: &'static str);
@@ -1713,8 +2202,7 @@ macro_rules! extensions_options {
                         // Safely apply deprecated attribute if present
                         // $(#[allow(deprecated)])?
                         {
-                            #[allow(deprecated)]
-                            self.$field_name.set(rem, value.as_ref())
+                            self.$field_name.set(rem, value.as_ref())
                         }
                     },
                 )*
@@ -1728,7 +2216,6 @@ macro_rules! extensions_options {
             $(
                 let key = stringify!($field_name).to_string();
                 let desc = concat!($($d),*).trim();
-                #[allow(deprecated)]
                 self.$field_name.visit(v, key.as_str(), desc);
             )*
         }
@@ -1902,7 +2389,7 @@ impl TableOptions {
     ///
     /// A result indicating success or failure in setting the configuration option.
     pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
-        let Some((prefix, _)) = key.split_once('.') else {
+        let Some((mut prefix, _)) = key.split_once('.') else {
             return _config_err!("could not find config namespace for key \"{key}\"");
         };
@@ -1914,6 +2401,15 @@ impl TableOptions {
             return Ok(());
         }

+        if !self.extensions.0.contains_key(prefix)
+            && self
+                .extensions
+                .0
+                .contains_key(DATAFUSION_FFI_CONFIG_NAMESPACE)
+        {
+            prefix = DATAFUSION_FFI_CONFIG_NAMESPACE;
+        }
+
         let Some(e) = self.extensions.0.get_mut(prefix) else {
             return _config_err!("Could not find config namespace \"{prefix}\"");
         };
@@ -1999,7 +2495,7 @@ impl TableOptions {

 /// Options that control how Parquet files are read, including global options
 /// that apply to all columns and optional column-specific overrides
 ///
-/// Closely tied to [`ParquetWriterOptions`](crate::file_options::parquet_writer::ParquetWriterOptions).
+/// Closely tied to `ParquetWriterOptions` (see `crate::file_options::parquet_writer::ParquetWriterOptions` when the "parquet" feature is enabled).
 /// Properties not included in [`TableParquetOptions`] may not be configurable at the external API
 /// (e.g. sorting_columns).
 #[derive(Clone, Default, Debug, PartialEq)]
@@ -2119,13 +2615,13 @@ impl ConfigField for TableParquetOptions {
             [_meta] | [_meta, ""] => {
                 return _config_err!(
                     "Invalid metadata key provided, missing key in metadata::"
-                )
+                );
             }
             [_meta, k] => k.into(),
             _ => {
                 return _config_err!(
                     "Invalid metadata key provided, found too many '::' in \"{key}\""
-                )
+                );
             }
         };
         self.key_value_metadata.insert(k, Some(value.into()));
@@ -2171,7 +2667,6 @@ macro_rules! config_namespace_with_hashmap {
             $(
                 stringify!($field_name) => {
                     // Handle deprecated fields
-                    #[allow(deprecated)] // Allow deprecated fields
                     $(let value = $transform(value);)?
                     self.$field_name.set(rem, value.as_ref())
                 },
@@ -2187,7 +2682,6 @@ macro_rules! config_namespace_with_hashmap {
                 let key = format!(concat!("{}.", stringify!($field_name)), key_prefix);
                 let desc = concat!($($d),*).trim();
                 // Handle deprecated fields
-                #[allow(deprecated)]
                 self.$field_name.visit(v, key.as_str(), desc);
             )*
         }
@@ -2195,7 +2689,6 @@ macro_rules! config_namespace_with_hashmap {

     impl Default for $struct_name {
         fn default() -> Self {
-            #[allow(deprecated)]
             Self {
                 $($field_name: $default),*
             }
@@ -2223,7 +2716,6 @@ macro_rules! config_namespace_with_hashmap {
             $(
                 let key = format!("{}.{field}::{}", key_prefix, column_name, field = stringify!($field_name));
                 let desc = concat!($($d),*).trim();
-                #[allow(deprecated)]
                 col_options.$field_name.visit(v, key.as_str(), desc);
             )*
         }
@@ -2254,7 +2746,7 @@ config_namespace_with_hashmap! {

     /// Sets default parquet compression codec for the column path.
     /// Valid values are: uncompressed, snappy, gzip(level),
-    /// lzo, brotli(level), lz4, zstd(level), and lz4_raw.
+    /// brotli(level), lz4, zstd(level), and lz4_raw.
     /// These values are not case-sensitive. If NULL, uses
     /// default parquet options
     pub compression: Option<String>, transform = str::to_lowercase, default = None
@@ -2437,10 +2929,7 @@ impl From<&Arc<FileEncryptionProperties>> for ConfigFileEncryptionProperties {
             },
         );
     }
-        let mut aad_prefix: Vec<u8> = Vec::new();
-        if let Some(prefix) = f.aad_prefix() {
-            aad_prefix = prefix.clone();
-        }
+        let aad_prefix = f.aad_prefix().cloned().unwrap_or_default();
         ConfigFileEncryptionProperties {
             encrypt_footer: f.encrypt_footer(),
             footer_key_as_hex: hex::encode(f.footer_key()),
@@ -2522,7 +3011,7 @@ impl ConfigField for ConfigFileDecryptionProperties {
                 self.footer_signature_verification.set(rem, value.as_ref())
             }
             _ => _config_err!(
-                "Config value \"{}\" not found on ConfigFileEncryptionProperties",
+                "Config value \"{}\" not found on ConfigFileDecryptionProperties",
                 key
             ),
         }
@@ -2564,8 +3053,18 @@ impl From<ConfigFileDecryptionProperties> for FileDecryptionProperties {
 }

 #[cfg(feature = "parquet_encryption")]
-impl From<&Arc<FileDecryptionProperties>> for ConfigFileDecryptionProperties {
-    fn from(f: &Arc<FileDecryptionProperties>) -> Self {
+impl TryFrom<&Arc<FileDecryptionProperties>> for ConfigFileDecryptionProperties {
+    type Error = DataFusionError;
+
+    fn try_from(f: &Arc<FileDecryptionProperties>) -> Result<Self> {
+        let footer_key = f.footer_key(None).map_err(|e| {
+            DataFusionError::Configuration(format!(
+                "Could not retrieve footer key from FileDecryptionProperties. \
+                Note that conversion to ConfigFileDecryptionProperties is not supported \
+                when using a key retriever: {e}"
+            ))
+        })?;
+
+        let (column_names_vec, column_keys_vec) = f.column_keys();
         let mut column_decryption_properties: HashMap<
             String,
@@ -2578,18 +3077,13 @@ impl From<&Arc<FileDecryptionProperties>> for ConfigFileDecryptionProperties {
             column_decryption_properties.insert(column_name.clone(), props);
         }

-        let mut aad_prefix: Vec<u8> = Vec::new();
-        if let Some(prefix) = f.aad_prefix() {
-            aad_prefix = prefix.clone();
-        }
-
-        ConfigFileDecryptionProperties {
-            footer_key_as_hex: hex::encode(
-                f.footer_key(None).unwrap_or_default().as_ref(),
-            ),
+        let aad_prefix = f.aad_prefix().cloned().unwrap_or_default();
+        Ok(ConfigFileDecryptionProperties {
+            footer_key_as_hex: hex::encode(footer_key.as_ref()),
             column_decryption_properties,
             aad_prefix_as_hex: hex::encode(aad_prefix),
             footer_signature_verification: f.check_plaintext_footer_integrity(),
-        }
+        })
     }
 }

@@ -2639,6 +3133,15 @@ config_namespace! {
     pub terminator: Option<u8>, default = None
     pub escape: Option<u8>, default = None
     pub double_quote: Option<bool>, default = None
+    /// Quote style for CSV writing.
+    /// One of: "Always", "Necessary", "NonNumeric", "Never"
+    pub quote_style: CsvQuoteStyle, default = CsvQuoteStyle::Necessary
+    /// Whether to ignore leading whitespace in string values when writing CSV.
+    /// Defaults to `false` when `None`.
+    pub ignore_leading_whitespace: Option<bool>, default = None
+    /// Whether to ignore trailing whitespace in string values when writing CSV.
+    /// Defaults to `false` when `None`.
+    pub ignore_trailing_whitespace: Option<bool>, default = None
     /// Specifies whether newlines in (quoted) values are supported.
     ///
     /// Parsing newlines in quoted values may be affected by execution behaviour such as
     ///
     /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting.
     pub newlines_in_values: Option<bool>, default = None
     pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
+    /// Compression level for the output file. The valid range depends on the
+    /// compression algorithm:
+    /// - ZSTD: 1 to 22 (default: 3)
+    /// - GZIP: 0 to 9 (default: 6)
+    /// - BZIP2: 0 to 9 (default: 6)
+    /// - XZ: 0 to 9 (default: 6)
+    /// If not specified, the default level for the compression algorithm is used.
+    pub compression_level: Option<u32>, default = None
     pub schema_infer_max_rec: Option<usize>, default = None
     pub date_format: Option<String>, default = None
     pub datetime_format: Option<String>, default = None
@@ -2739,6 +3250,30 @@ impl CsvOptions {
         self
     }

+    /// Set the quote style for CSV writing.
+    pub fn with_quote_style(mut self, quote_style: CsvQuoteStyle) -> Self {
+        self.quote_style = quote_style;
+        self
+    }
+
+    /// Set whether to ignore leading whitespace in string values when writing CSV.
+    pub fn with_ignore_leading_whitespace(
+        mut self,
+        ignore_leading_whitespace: bool,
+    ) -> Self {
+        self.ignore_leading_whitespace = Some(ignore_leading_whitespace);
+        self
+    }
+
+    /// Set whether to ignore trailing whitespace in string values when writing CSV.
+    pub fn with_ignore_trailing_whitespace(
+        mut self,
+        ignore_trailing_whitespace: bool,
+    ) -> Self {
+        self.ignore_trailing_whitespace = Some(ignore_trailing_whitespace);
+        self
+    }
+
     /// Specifies whether newlines in (quoted) values are supported.
     ///
     /// Parsing newlines in quoted values may be affected by execution behaviour such as
@@ -2770,6 +3305,14 @@ impl CsvOptions {
         self
     }

+    /// Set the compression level for the output file.
+    /// The valid range depends on the compression algorithm.
+    /// If not specified, the default level for the algorithm is used.
+    pub fn with_compression_level(mut self, level: u32) -> Self {
+        self.compression_level = Some(level);
+        self
+    }
+
     /// The delimiter character.
     pub fn delimiter(&self) -> u8 {
         self.delimiter
@@ -2795,14 +3338,38 @@ config_namespace! {
     /// Options controlling JSON format
     pub struct JsonOptions {
         pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED
+        /// Compression level for the output file. The valid range depends on the
+        /// compression algorithm:
+        /// - ZSTD: 1 to 22 (default: 3)
+        /// - GZIP: 0 to 9 (default: 6)
+        /// - BZIP2: 0 to 9 (default: 6)
+        /// - XZ: 0 to 9 (default: 6)
+        /// If not specified, the default level for the compression algorithm is used.
+        pub compression_level: Option<u32>, default = None
         pub schema_infer_max_rec: Option<usize>, default = None
+        /// The JSON format to use when reading files.
+        ///
+        /// When `true` (default), expects newline-delimited JSON (NDJSON):
+        /// ```text
+        /// {"key1": 1, "key2": "val"}
+        /// {"key1": 2, "key2": "vals"}
+        /// ```
+        ///
+        /// When `false`, expects JSON array format:
+        /// ```text
+        /// [
+        ///   {"key1": 1, "key2": "val"},
+        ///   {"key1": 2, "key2": "vals"}
+        /// ]
+        /// ```
+        pub newline_delimited: bool, default = true
     }
 }
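The new CSV writer knobs compose with the existing builder style. A short sketch, assuming the builder names from the diff above and the `NonNumeric` variant listed in the `quote_style` docs:

```rust
use datafusion_common::config::CsvOptions;
use datafusion_common::parsers::CsvQuoteStyle;

fn main() {
    let options = CsvOptions::default()
        .with_quote_style(CsvQuoteStyle::NonNumeric) // quote only non-numeric fields
        .with_ignore_trailing_whitespace(true)
        .with_compression_level(6); // e.g. GZIP level 6

    assert_eq!(options.compression_level, Some(6));
    assert_eq!(options.ignore_trailing_whitespace, Some(true));
}
```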
 pub trait OutputFormatExt: Display {}

 #[derive(Debug, Clone, PartialEq)]
-#[allow(clippy::large_enum_variant)]
+#[cfg_attr(feature = "parquet", expect(clippy::large_enum_variant))]
 pub enum OutputFormat {
     CSV(CsvOptions),
     JSON(JsonOptions),
@@ -2836,7 +3403,6 @@ mod tests {
     };
     use std::any::Any;
     use std::collections::HashMap;
-    use std::sync::Arc;

     #[derive(Default, Debug, Clone)]
     pub struct TestExtensionConfig {
@@ -2908,6 +3474,16 @@ mod tests {
         );
     }

+    #[test]
+    fn iter_test_extension_config() {
+        let mut extension = Extensions::new();
+        extension.insert(TestExtensionConfig::default());
+        let table_config = TableOptions::new().with_extensions(extension);
+        let extensions = table_config.extensions.iter().collect::<Vec<_>>();
+        assert_eq!(extensions.len(), 1);
+        assert_eq!(extensions[0].0, TestExtensionConfig::PREFIX);
+    }
+
     #[test]
     fn csv_u8_table_options() {
         let mut table_config = TableOptions::new();
@@ -2951,6 +3527,19 @@ mod tests {
         assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 1);
     }

+    #[test]
+    fn reset_nested_scalar_reports_helpful_error() {
+        let mut value = true;
+        let err = <bool as ConfigField>::reset(&mut value, "nested").unwrap_err();
+        let message = err.to_string();
+        assert!(
+            message.starts_with(
+                "Invalid or Unsupported Configuration: Config field is a scalar bool and does not have nested field \"nested\""
+            ),
+            "unexpected error message: {message}"
+        );
+    }
+
     #[cfg(feature = "parquet")]
     #[test]
     fn parquet_table_options() {
@@ -2973,6 +3562,7 @@ mod tests {
         };
         use parquet::encryption::decrypt::FileDecryptionProperties;
         use parquet::encryption::encrypt::FileEncryptionProperties;
+        use std::sync::Arc;

         let footer_key = b"0123456789012345".to_vec(); // 128bit/16
         let column_names = vec!["double_field", "float_field"];
@@ -2999,7 +3589,8 @@ mod tests {
             Arc::new(FileEncryptionProperties::from(config_encrypt.clone()));
         assert_eq!(file_encryption_properties, encryption_properties_built);

-        let config_decrypt = ConfigFileDecryptionProperties::from(&decryption_properties);
+        let config_decrypt =
+            ConfigFileDecryptionProperties::try_from(&decryption_properties).unwrap();
         let decryption_properties_built =
             Arc::new(FileDecryptionProperties::from(config_decrypt.clone()));
         assert_eq!(decryption_properties, decryption_properties_built);

@@ -3117,6 +3708,42 @@ mod tests {
         assert_eq!(factory_options.get("key2"), Some(&"value 2".to_string()));
     }

+    #[cfg(feature = "parquet_encryption")]
+    struct ParquetEncryptionKeyRetriever {}
+
+    #[cfg(feature = "parquet_encryption")]
+    impl parquet::encryption::decrypt::KeyRetriever for ParquetEncryptionKeyRetriever {
+        fn retrieve_key(&self, key_metadata: &[u8]) -> parquet::errors::Result<Vec<u8>> {
+            if !key_metadata.is_empty() {
+                Ok(b"1234567890123450".to_vec())
+            } else {
+                Err(parquet::errors::ParquetError::General(
+                    "Key metadata not provided".to_string(),
+                ))
+            }
+        }
+    }
+
+    #[cfg(feature = "parquet_encryption")]
+    #[test]
+    fn conversion_from_key_retriever_to_config_file_decryption_properties() {
+        use crate::Result;
+        use crate::config::ConfigFileDecryptionProperties;
+        use crate::encryption::FileDecryptionProperties;
+
+        let retriever = std::sync::Arc::new(ParquetEncryptionKeyRetriever {});
+        let decryption_properties =
+            FileDecryptionProperties::with_key_retriever(retriever)
+                .build()
+                .unwrap();
+        let config_file_decryption_properties: Result<ConfigFileDecryptionProperties> =
+            (&decryption_properties).try_into();
+        assert!(config_file_decryption_properties.is_err());
+        let err = config_file_decryption_properties.unwrap_err().to_string();
+        assert!(err.contains("key retriever"));
+        assert!(err.contains("Key metadata not provided"));
+    }
+
     #[cfg(feature = "parquet")]
     #[test]
     fn parquet_table_options_config_entry() {
@@ -3126,9 +3753,11 @@ mod tests {
             .set("format.bloom_filter_enabled::col1", "true")
             .unwrap();
         let entries = table_config.entries();
-        assert!(entries
-            .iter()
-            .any(|item| item.key == "format.bloom_filter_enabled::col1"))
+        assert!(
+            entries
+                .iter()
+                .any(|item| item.key == "format.bloom_filter_enabled::col1")
+        )
     }

     #[cfg(feature = "parquet")]
@@ -3142,10 +3771,10 @@ mod tests {
         )
         .unwrap();
         let entries = table_parquet_options.entries();
-        assert!(entries
-            .iter()
-            .any(|item| item.key
-                == "crypto.file_encryption.column_key_as_hex::double_field"))
+        assert!(
+            entries.iter().any(|item| item.key
+                == "crypto.file_encryption.column_key_as_hex::double_field")
+        )
     }

     #[cfg(feature = "parquet")]
@@ -3181,4 +3810,110 @@ mod tests {
         let parsed_metadata = table_config.parquet.key_value_metadata;
         assert_eq!(parsed_metadata.get("key_dupe"), Some(&Some("B".into())));
     }
+    #[cfg(feature = "parquet")]
+    #[test]
+    fn test_parquet_writer_version_validation() {
+        use crate::{config::ConfigOptions, parquet_config::DFParquetWriterVersion};
+
+        let mut config = ConfigOptions::default();
+
+        // Valid values should work
+        config
+            .set("datafusion.execution.parquet.writer_version", "1.0")
+            .unwrap();
+        assert_eq!(
+            config.execution.parquet.writer_version,
+            DFParquetWriterVersion::V1_0
+        );
+
+        config
+            .set("datafusion.execution.parquet.writer_version", "2.0")
+            .unwrap();
+        assert_eq!(
+            config.execution.parquet.writer_version,
+            DFParquetWriterVersion::V2_0
+        );
+
+        // Invalid value should error immediately at SET time
+        let err = config
+            .set("datafusion.execution.parquet.writer_version", "3.0")
+            .unwrap_err();
+        assert_eq!(
+            err.to_string(),
+            "Invalid or Unsupported Configuration: Invalid parquet writer version: 3.0. 
Expected one of: 1.0, 2.0" + ); + } + + #[cfg(feature = "parquet")] + #[test] + fn set_cdc_option_with_boolean_true() { + use crate::config::ConfigOptions; + + let mut config = ConfigOptions::default(); + assert!( + config + .execution + .parquet + .use_content_defined_chunking + .is_none() + ); + + // Setting to "true" should enable CDC with default options + config + .set( + "datafusion.execution.parquet.use_content_defined_chunking", + "true", + ) + .unwrap(); + let cdc = config + .execution + .parquet + .use_content_defined_chunking + .as_ref() + .expect("CDC should be enabled"); + assert_eq!(cdc.min_chunk_size, 256 * 1024); + assert_eq!(cdc.max_chunk_size, 1024 * 1024); + assert_eq!(cdc.norm_level, 0); + + // Setting to "false" should disable CDC + config + .set( + "datafusion.execution.parquet.use_content_defined_chunking", + "false", + ) + .unwrap(); + assert!( + config + .execution + .parquet + .use_content_defined_chunking + .is_none() + ); + } + + #[cfg(feature = "parquet")] + #[test] + fn set_cdc_option_with_subfields() { + use crate::config::ConfigOptions; + + let mut config = ConfigOptions::default(); + + // Setting sub-fields should also enable CDC + config + .set( + "datafusion.execution.parquet.use_content_defined_chunking.min_chunk_size", + "1024", + ) + .unwrap(); + let cdc = config + .execution + .parquet + .use_content_defined_chunking + .as_ref() + .expect("CDC should be enabled"); + assert_eq!(cdc.min_chunk_size, 1024); + // Other fields should be defaults + assert_eq!(cdc.max_chunk_size, 1024 * 1024); + assert_eq!(cdc.norm_level, 0); + } } diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f8..93169d6a02ff1 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -19,12 +19,12 @@ //! a [`CSEController`], that defines how to eliminate common subtrees from a particular //! [`TreeNode`] tree. +use crate::Result; use crate::hash_utils::combine_hashes; use crate::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use crate::Result; use indexmap::IndexMap; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher, RandomState}; @@ -676,13 +676,13 @@ where #[cfg(test)] mod test { + use crate::Result; use crate::alias::AliasGenerator; use crate::cse::{ - CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, - Normalizeable, CSE, + CSE, CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, + Normalizeable, }; use crate::tree_node::tests::TestTreeNode; - use crate::Result; use std::collections::HashSet; use std::hash::{Hash, Hasher}; diff --git a/datafusion/common/src/datatype.rs b/datafusion/common/src/datatype.rs index 65f6395211866..19847f8583505 100644 --- a/datafusion/common/src/datatype.rs +++ b/datafusion/common/src/datatype.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! [`DataTypeExt`] and [`FieldExt`] extension trait for working with DataTypes to Fields +//! 
[`DataTypeExt`] and [`FieldExt`] extension traits for working with Arrow [`DataType`] and [`Field`]s

 use crate::arrow::datatypes::{DataType, Field, FieldRef};
+use crate::metadata::FieldMetadata;
 use std::sync::Arc;

 /// DataFusion extension methods for Arrow [`DataType`]
@@ -61,7 +62,54 @@ impl DataTypeExt for DataType {
 }

 /// DataFusion extension methods for Arrow [`Field`] and [`FieldRef`]
+///
+/// This trait is implemented for both [`Field`] and [`FieldRef`] and
+/// provides convenience methods for efficiently working with both types.
+///
+/// For [`FieldRef`], the methods will attempt to unwrap the `Arc`
+/// to avoid unnecessary cloning when possible.
 pub trait FieldExt {
+    /// Ensure the field is named `new_name`, returning the given field if the
+    /// name matches, and a new field if not.
+    ///
+    /// This method avoids `clone`ing fields and names if the name is the same
+    /// as the field's existing name.
+    ///
+    /// Example:
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use arrow::datatypes::{DataType, Field};
+    /// # use datafusion_common::datatype::FieldExt;
+    /// let int_field = Field::new("my_int", DataType::Int32, true);
+    /// // rename to "your_int"
+    /// let renamed_field = int_field.renamed("your_int");
+    /// assert_eq!(renamed_field.name(), "your_int");
+    /// ```
+    fn renamed(self, new_name: &str) -> Self;
+
+    /// Ensure the field has the given data type
+    ///
+    /// Note this is different than simply calling [`Field::with_data_type`] as
+    /// it avoids copying if the data type is already the same.
+    ///
+    /// Example:
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use arrow::datatypes::{DataType, Field};
+    /// # use datafusion_common::datatype::FieldExt;
+    /// let int_field = Field::new("my_int", DataType::Int32, true);
+    /// // change to Float64
+    /// let retyped_field = int_field.retyped(DataType::Float64);
+    /// assert_eq!(retyped_field.data_type(), &DataType::Float64);
+    /// ```
+    fn retyped(self, new_data_type: DataType) -> Self;
+
+    /// Add field metadata to the Field
+    fn with_field_metadata(self, metadata: &FieldMetadata) -> Self;
+
+    /// Add optional field metadata to the Field, if provided
+    fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self;
+
     /// Returns a new Field representing a List of this Field's DataType.
     ///
     /// For example if input represents an `Int32`, the return value will
@@ -130,6 +178,32 @@ pub trait FieldExt {
 }

 impl FieldExt for Field {
+    fn renamed(self, new_name: &str) -> Self {
+        // check if this is a new name before allocating a new Field / copying
+        // the existing one
+        if self.name() != new_name {
+            self.with_name(new_name)
+        } else {
+            self
+        }
+    }
+
+    fn retyped(self, new_data_type: DataType) -> Self {
+        self.with_data_type(new_data_type)
+    }
+
+    fn with_field_metadata(self, metadata: &FieldMetadata) -> Self {
+        metadata.add_to_field(self)
+    }
+
+    fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self {
+        if let Some(metadata) = metadata {
+            self.with_field_metadata(metadata)
+        } else {
+            self
+        }
+    }
+
     fn into_list(self) -> Self {
         DataType::List(Arc::new(self.into_list_item())).into_nullable_field()
     }
@@ -149,6 +223,34 @@ impl FieldExt for Field {
 }

 impl FieldExt for Arc<Field> {
+    fn renamed(mut self, new_name: &str) -> Self {
+        if self.name() != new_name {
+            // avoid cloning if possible
+            Arc::make_mut(&mut self).set_name(new_name);
+        }
+        self
+    }
+
+    fn retyped(mut self, new_data_type: DataType) -> Self {
+        if self.data_type() != &new_data_type {
+            // avoid cloning if possible
+            Arc::make_mut(&mut self).set_data_type(new_data_type);
+        }
+        self
+    }
+
+    fn with_field_metadata(self, metadata: &FieldMetadata) -> Self {
+        metadata.add_to_field_ref(self)
+    }
+
+    fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self {
+        if let Some(metadata) = metadata {
+            self.with_field_metadata(metadata)
+        } else {
+            self
+        }
+    }
+
     fn into_list(self) -> Self {
         DataType::List(self.into_list_item())
             .into_nullable_field()
@@ -161,13 +263,11 @@ impl FieldExt for Arc<Field> {
             .into()
     }

-    fn into_list_item(self) -> Self {
+    fn into_list_item(mut self) -> Self {
         if self.name() != Field::LIST_FIELD_DEFAULT_NAME {
-            Arc::unwrap_or_clone(self)
-                .with_name(Field::LIST_FIELD_DEFAULT_NAME)
-                .into()
-        } else {
-            self
+            // avoid cloning if possible
+            Arc::make_mut(&mut self).set_name(Field::LIST_FIELD_DEFAULT_NAME);
         }
+        self
     }
 }
diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 24d152a7dba8c..e3da99163ed69 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -21,12 +21,12 @@ use std::collections::{BTreeSet, HashMap, HashSet};
 use std::fmt::{Display, Formatter};
 use std::hash::Hash;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};

-use crate::error::{DataFusionError, Result, _plan_err, _schema_err};
+use crate::error::{_plan_err, _schema_err, DataFusionError, Result};
 use crate::{
-    field_not_found, unqualified_field_not_found, Column, FunctionalDependencies,
-    SchemaError, TableReference,
+    Column, FunctionalDependencies, SchemaError, TableReference, field_not_found,
+    unqualified_field_not_found,
 };

 use arrow::compute::can_cast_types;
@@ -37,7 +37,7 @@ use arrow::datatypes::{

 /// A reference-counted reference to a [DFSchema].
 pub type DFSchemaRef = Arc<DFSchema>;

-/// DFSchema wraps an Arrow schema and adds relation names.
+/// DFSchema wraps an Arrow schema and adds a relation (table) name.
 ///
 /// The schema may hold the fields across multiple tables. Some fields may be
 /// qualified and some unqualified. A qualified field is a field that has a
@@ -47,8 +47,14 @@ pub type DFSchemaRef = Arc<DFSchema>;
 /// have a distinct name from any qualified field names. This allows finding a
 /// qualified field by name to be possible, so long as there aren't multiple
 /// qualified fields with the same name.
+///
+/// # See Also
+/// * [DFSchemaRef], an alias to `Arc<DFSchema>`
+/// * [DataTypeExt], common methods for working with Arrow [DataType]s
+/// * [FieldExt], extension methods for working with Arrow [Field]s
 ///
-/// There is an alias to `Arc<DFSchema>` named [DFSchemaRef].
+/// [DataTypeExt]: crate::datatype::DataTypeExt
+/// [FieldExt]: crate::datatype::FieldExt
 ///
 /// # Creating qualified schemas
 ///
@@ -123,6 +129,13 @@ impl DFSchema {
         }
     }

+    /// Returns a reference to a shared empty [`DFSchema`].
+    pub fn empty_ref() -> &'static DFSchemaRef {
+        static EMPTY: LazyLock<DFSchemaRef> =
+            LazyLock::new(|| Arc::new(DFSchema::empty()));
+        &EMPTY
+    }
+
     /// Return a reference to the inner Arrow [`Schema`]
     ///
     /// Note this does not have the qualifier information
@@ -346,20 +359,22 @@ impl DFSchema {
         self.field_qualifiers.extend(qualifiers);
     }

-    /// Get a list of fields
+    /// Get a list of fields for this schema
     pub fn fields(&self) -> &Fields {
         &self.inner.fields
     }

-    /// Returns an immutable reference of a specific `Field` instance selected using an
-    /// offset within the internal `fields` vector
-    pub fn field(&self, i: usize) -> &Field {
+    /// Returns a reference to [`FieldRef`] for a column at a specific index
+    /// within the schema.
+    ///
+    /// See also [Self::qualified_field] to get both qualifier and field
+    pub fn field(&self, i: usize) -> &FieldRef {
         &self.inner.fields[i]
     }

-    /// Returns an immutable reference of a specific `Field` instance selected using an
-    /// offset within the internal `fields` vector and its qualifier
-    pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &Field) {
+    /// Returns the qualifier (if any) and [`FieldRef`] for a column at a specific
+    /// index within the schema.
+    pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &FieldRef) {
         (self.field_qualifiers[i].as_ref(), self.field(i))
     }

@@ -410,12 +425,12 @@ impl DFSchema {
             .is_some()
     }

-    /// Find the field with the given name
+    /// Find the [`FieldRef`] with the given name and optional qualifier
     pub fn field_with_name(
         &self,
         qualifier: Option<&TableReference>,
         name: &str,
-    ) -> Result<&Field> {
+    ) -> Result<&FieldRef> {
         if let Some(qualifier) = qualifier {
             self.field_with_qualified_name(qualifier, name)
         } else {
@@ -428,7 +443,7 @@ impl DFSchema {
         &self,
         qualifier: Option<&TableReference>,
         name: &str,
-    ) -> Result<(Option<&TableReference>, &Field)> {
+    ) -> Result<(Option<&TableReference>, &FieldRef)> {
         if let Some(qualifier) = qualifier {
             let idx = self
                 .index_of_column_by_name(Some(qualifier), name)
@@ -440,10 +455,10 @@ impl DFSchema {
     }

     /// Find all fields having the given qualifier
-    pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&Field> {
+    pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&FieldRef> {
         self.iter()
             .filter(|(q, _)| q.map(|q| q.eq(qualifier)).unwrap_or(false))
-            .map(|(_, f)| f.as_ref())
+            .map(|(_, f)| f)
             .collect()
     }

@@ -459,11 +474,10 @@ impl DFSchema {
     }

     /// Find all fields that match the given name
-    pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&Field> {
+    pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&FieldRef> {
         self.fields()
             .iter()
             .filter(|field| field.name() == name)
-            .map(|f| f.as_ref())
             .collect()
     }

     /// Find all fields that match the given name
     pub fn qualified_fields_with_unqualified_name(
         &self,
         name: &str,
-    ) -> Vec<(Option<&TableReference>, &Field)> {
+    ) -> Vec<(Option<&TableReference>, &FieldRef)> {
         self.iter()
             .filter(|(_, field)| field.name() == name)
-            .map(|(qualifier, field)| 
(qualifier, field.as_ref())) .collect() } @@ -499,7 +512,7 @@ impl DFSchema { pub fn qualified_field_with_unqualified_name( &self, name: &str, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), @@ -528,7 +541,7 @@ impl DFSchema { } /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&Field> { + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&FieldRef> { self.qualified_field_with_unqualified_name(name) .map(|(_, field)| field) } @@ -538,7 +551,7 @@ impl DFSchema { &self, qualifier: &TableReference, name: &str, - ) -> Result<&Field> { + ) -> Result<&FieldRef> { let idx = self .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; @@ -550,7 +563,7 @@ impl DFSchema { pub fn qualified_field_from_column( &self, column: &Column, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { self.qualified_field_with_name(column.relation.as_ref(), &column.name) } @@ -692,10 +705,12 @@ impl DFSchema { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref() == v2.as_ref() + Self::datatype_is_logically_equal(v1.as_ref(), v2.as_ref()) + } + (DataType::Dictionary(_, v1), othertype) + | (othertype, DataType::Dictionary(_, v1)) => { + Self::datatype_is_logically_equal(v1.as_ref(), othertype) } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { @@ -982,36 +997,35 @@ fn format_field_with_indent( result.push_str(&format!( "{indent}|-- {field_name}: map (nullable = {nullable_str})\n" )); - if let DataType::Struct(inner_fields) = field.data_type() { - if inner_fields.len() == 2 { - format_field_with_indent( - result, - "key", - inner_fields[0].data_type(), - inner_fields[0].is_nullable(), - &child_indent, - ); - let value_contains_null = - field.is_nullable().to_string().to_lowercase(); - // Handle complex value types properly - match inner_fields[1].data_type() { - DataType::Struct(_) - | DataType::List(_) - | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) - | DataType::Map(_, _) => { - format_field_with_indent( - result, - "value", - inner_fields[1].data_type(), - inner_fields[1].is_nullable(), - &child_indent, - ); - } - _ => { - result.push_str(&format!("{child_indent}|-- value: {} (nullable = {value_contains_null})\n", + if let DataType::Struct(inner_fields) = field.data_type() + && inner_fields.len() == 2 + { + format_field_with_indent( + result, + "key", + inner_fields[0].data_type(), + inner_fields[0].is_nullable(), + &child_indent, + ); + let value_contains_null = field.is_nullable().to_string().to_lowercase(); + // Handle complex value types properly + match inner_fields[1].data_type() { + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Map(_, _) => { + format_field_with_indent( + result, + "value", + inner_fields[1].data_type(), + inner_fields[1].is_nullable(), + &child_indent, + ); + } + _ => { + result.push_str(&format!("{child_indent}|-- 
value: {} (nullable = {value_contains_null})\n", format_simple_data_type(inner_fields[1].data_type()))); - } } } } @@ -1129,6 +1143,12 @@ impl TryFrom for DFSchema { } } +impl From for SchemaRef { + fn from(dfschema: DFSchema) -> Self { + Arc::clone(&dfschema.inner) + } +} + // Hashing refers to a subset of fields considered in PartialEq. impl Hash for DFSchema { fn hash(&self, state: &mut H) { @@ -1221,7 +1241,7 @@ pub trait ExprSchema: std::fmt::Debug { } // Return the column's field - fn field_from_column(&self, col: &Column) -> Result<&Field>; + fn field_from_column(&self, col: &Column) -> Result<&FieldRef>; } // Implement `ExprSchema` for `Arc` @@ -1242,13 +1262,13 @@ impl + std::fmt::Debug> ExprSchema for P { self.as_ref().data_type_and_nullable(col) } - fn field_from_column(&self, col: &Column) -> Result<&Field> { + fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { self.as_ref().field_from_column(col) } } impl ExprSchema for DFSchema { - fn field_from_column(&self, col: &Column) -> Result<&Field> { + fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { match &col.relation { Some(r) => self.field_with_qualified_name(r, &col.name), None => self.field_with_unqualified_name(&col.name), @@ -1325,11 +1345,44 @@ impl SchemaExt for Schema { } } +/// Build a fully-qualified field name string. This is equivalent to +/// `format!("{q}.{name}")` when `qualifier` is `Some`, or just `name` when +/// `None`. We avoid going through the `fmt` machinery for performance reasons. pub fn qualified_name(qualifier: Option<&TableReference>, name: &str) -> String { - match qualifier { - Some(q) => format!("{q}.{name}"), - None => name.to_string(), - } + let qualifier = match qualifier { + None => return name.to_string(), + Some(q) => q, + }; + let (first, second, third) = match qualifier { + TableReference::Bare { table } => (table.as_ref(), None, None), + TableReference::Partial { schema, table } => { + (schema.as_ref(), Some(table.as_ref()), None) + } + TableReference::Full { + catalog, + schema, + table, + } => ( + catalog.as_ref(), + Some(schema.as_ref()), + Some(table.as_ref()), + ), + }; + + let extra = second.map_or(0, str::len) + third.map_or(0, str::len); + let mut s = String::with_capacity(first.len() + extra + 3 + name.len()); + s.push_str(first); + if let Some(second) = second { + s.push('.'); + s.push_str(second); + } + if let Some(third) = third { + s.push('.'); + s.push_str(third); + } + s.push('.'); + s.push_str(name); + s } #[cfg(test)] @@ -1338,6 +1391,36 @@ mod tests { use super::*; + /// `qualified_name` doesn't use `TableReference::Display` for performance + /// reasons, but check that the output is consistent. + #[test] + fn qualified_name_agrees_with_display() { + let cases: &[(Option, &str)] = &[ + (None, "col"), + (Some(TableReference::bare("t")), "c0"), + (Some(TableReference::partial("s", "t")), "c0"), + (Some(TableReference::full("c", "s", "t")), "c0"), + (Some(TableReference::bare("mytable")), "some_column_name"), + // Empty segments must be preserved so that distinct qualified + // fields don't collide in `DFSchema::field_names()`. 
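+            // (Illustration: `bare("")` renders as ".col", which must stay
+            // distinct from the unqualified name "col".)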
+ (Some(TableReference::bare("")), "col"), + (Some(TableReference::partial("s", "")), "col"), + (Some(TableReference::partial("", "t")), "col"), + (Some(TableReference::full("c", "", "t")), "col"), + (Some(TableReference::full("", "s", "t")), "col"), + (Some(TableReference::full("c", "s", "")), "col"), + (Some(TableReference::full("", "", "")), "col"), + ]; + for (qualifier, name) in cases { + let actual = qualified_name(qualifier.as_ref(), name); + let expected = match qualifier { + Some(q) => format!("{q}.{name}"), + None => name.to_string(), + }; + assert_eq!(actual, expected, "qualifier={qualifier:?} name={name}"); + } + } + #[test] fn qualifier_in_name() -> Result<()> { let col = Column::from_name("t1.c0"); @@ -1433,12 +1516,14 @@ mod tests { join.to_string() ); // test valid access - assert!(join - .field_with_qualified_name(&TableReference::bare("t1"), "c0") - .is_ok()); - assert!(join - .field_with_qualified_name(&TableReference::bare("t2"), "c0") - .is_ok()); + assert!( + join.field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok() + ); + assert!( + join.field_with_qualified_name(&TableReference::bare("t2"), "c0") + .is_ok() + ); // test invalid access assert!(join.field_with_unqualified_name("c0").is_err()); assert!(join.field_with_unqualified_name("t1.c0").is_err()); @@ -1480,18 +1565,20 @@ mod tests { join.to_string() ); // test valid access - assert!(join - .field_with_qualified_name(&TableReference::bare("t1"), "c0") - .is_ok()); + assert!( + join.field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok() + ); assert!(join.field_with_unqualified_name("c0").is_ok()); assert!(join.field_with_unqualified_name("c100").is_ok()); assert!(join.field_with_name(None, "c100").is_ok()); // test invalid access assert!(join.field_with_unqualified_name("t1.c0").is_err()); assert!(join.field_with_unqualified_name("t1.c100").is_err()); - assert!(join - .field_with_qualified_name(&TableReference::bare(""), "c100") - .is_err()); + assert!( + join.field_with_qualified_name(&TableReference::bare(""), "c100") + .is_err() + ); Ok(()) } @@ -1500,9 +1587,11 @@ mod tests { let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from(test_schema_1())?; let join = left.join(&right); - assert_contains!(join.unwrap_err().to_string(), - "Schema error: Schema contains qualified \ - field name t1.c0 and unqualified field name c0 which would be ambiguous"); + assert_contains!( + join.unwrap_err().to_string(), + "Schema error: Schema contains qualified \ + field name t1.c0 and unqualified field name c0 which would be ambiguous" + ); Ok(()) } @@ -1781,6 +1870,27 @@ mod tests { &DataType::Utf8, &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) )); + + // Dictionary is logically equal to the logically equivalent value type + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Utf8View, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8, false).into() + )) + ), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8View, false).into() + )) + ) + )); } #[test] @@ -2059,7 +2169,7 @@ mod tests { fn test_print_schema_empty() { let schema = DFSchema::empty(); let output = schema.tree_string(); - insta::assert_snapshot!(output, @r###"root"###); + 
insta::assert_snapshot!(output, @"root"); } #[test] diff --git a/datafusion/common/src/display/human_readable.rs b/datafusion/common/src/display/human_readable.rs new file mode 100644 index 0000000000000..0e0d677bd8904 --- /dev/null +++ b/datafusion/common/src/display/human_readable.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helpers for rendering sizes, counts, and durations in human readable form. + +/// Common data size units +pub mod units { + pub const TB: u64 = 1 << 40; + pub const GB: u64 = 1 << 30; + pub const MB: u64 = 1 << 20; + pub const KB: u64 = 1 << 10; +} + +/// Present size in human-readable form +pub fn human_readable_size(size: usize) -> String { + use units::*; + + let size = size as u64; + let (value, unit) = { + if size >= 2 * TB { + (size as f64 / TB as f64, "TB") + } else if size >= 2 * GB { + (size as f64 / GB as f64, "GB") + } else if size >= 2 * MB { + (size as f64 / MB as f64, "MB") + } else if size >= 2 * KB { + (size as f64 / KB as f64, "KB") + } else { + (size as f64, "B") + } + }; + format!("{value:.1} {unit}") +} + +/// Present count in human-readable form with K, M, B, T suffixes +pub fn human_readable_count(count: usize) -> String { + let count = count as u64; + let (value, unit) = { + if count >= 1_000_000_000_000 { + (count as f64 / 1_000_000_000_000.0, " T") + } else if count >= 1_000_000_000 { + (count as f64 / 1_000_000_000.0, " B") + } else if count >= 1_000_000 { + (count as f64 / 1_000_000.0, " M") + } else if count >= 1_000 { + (count as f64 / 1_000.0, " K") + } else { + return count.to_string(); + } + }; + + // Format with appropriate precision + // For values >= 100, show 1 decimal place (e.g., 123.4 K) + // For values < 100, show 2 decimal places (e.g., 10.12 K) + if value >= 100.0 { + format!("{value:.1}{unit}") + } else { + format!("{value:.2}{unit}") + } +} + +/// Present duration in human-readable form with 2 decimal places +pub fn human_readable_duration(nanos: u64) -> String { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + const NANOS_PER_MILLI: f64 = 1_000_000.0; + const NANOS_PER_MICRO: f64 = 1_000.0; + + let nanos_f64 = nanos as f64; + + if nanos >= 1_000_000_000 { + // >= 1 second: show in seconds + format!("{:.2}s", nanos_f64 / NANOS_PER_SEC) + } else if nanos >= 1_000_000 { + // >= 1 millisecond: show in milliseconds + format!("{:.2}ms", nanos_f64 / NANOS_PER_MILLI) + } else if nanos >= 1_000 { + // >= 1 microsecond: show in microseconds + format!("{:.2}µs", nanos_f64 / NANOS_PER_MICRO) + } else { + // < 1 microsecond: show in nanoseconds + format!("{nanos}ns") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_human_readable_count() { + assert_eq!(human_readable_count(0), "0"); + assert_eq!(human_readable_count(1), "1"); + 
assert_eq!(human_readable_count(999), "999"); + assert_eq!(human_readable_count(1_000), "1.00 K"); + assert_eq!(human_readable_count(10_100), "10.10 K"); + assert_eq!(human_readable_count(1_532), "1.53 K"); + assert_eq!(human_readable_count(99_999), "100.00 K"); + assert_eq!(human_readable_count(1_000_000), "1.00 M"); + assert_eq!(human_readable_count(1_532_000), "1.53 M"); + assert_eq!(human_readable_count(99_000_000), "99.00 M"); + assert_eq!(human_readable_count(123_456_789), "123.5 M"); + assert_eq!(human_readable_count(1_000_000_000), "1.00 B"); + assert_eq!(human_readable_count(1_532_000_000), "1.53 B"); + assert_eq!(human_readable_count(999_999_999_999), "1000.0 B"); + assert_eq!(human_readable_count(1_000_000_000_000), "1.00 T"); + assert_eq!(human_readable_count(42_000_000_000_000), "42.00 T"); + } + + #[test] + fn test_human_readable_duration() { + assert_eq!(human_readable_duration(0), "0ns"); + assert_eq!(human_readable_duration(1), "1ns"); + assert_eq!(human_readable_duration(999), "999ns"); + assert_eq!(human_readable_duration(1_000), "1.00µs"); + assert_eq!(human_readable_duration(1_234), "1.23µs"); + assert_eq!(human_readable_duration(999_999), "1000.00µs"); + assert_eq!(human_readable_duration(1_000_000), "1.00ms"); + assert_eq!(human_readable_duration(11_295_377), "11.30ms"); + assert_eq!(human_readable_duration(1_234_567), "1.23ms"); + assert_eq!(human_readable_duration(999_999_999), "1000.00ms"); + assert_eq!(human_readable_duration(1_000_000_000), "1.00s"); + assert_eq!(human_readable_duration(1_234_567_890), "1.23s"); + assert_eq!(human_readable_duration(42_000_000_000), "42.00s"); + } +} diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index bad51c45f8ee8..a6a97b243f06a 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -18,6 +18,7 @@ //! Types for plan display mod graphviz; +pub mod human_readable; pub use graphviz::*; use std::{ diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index fde52944d0497..c6c50371c26c1 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -15,7 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion error types +//! # Error Handling in DataFusion +//! +//! In DataFusion, there are two types of errors that can be raised: +//! +//! 1. Expected errors – These indicate invalid operations performed by the caller, +//! such as attempting to open a non-existent file. Different categories exist to +//! distinguish their sources (e.g., [`DataFusionError::ArrowError`], +//! [`DataFusionError::IoError`], etc.). +//! +//! 2. Unexpected errors – Represented by [`DataFusionError::Internal`], these +//! indicate that an internal invariant has been broken, suggesting a potential +//! bug in the system. +//! +//! There are several convenient macros for throwing errors. For example, use +//! `exec_err!` for expected errors. +//! For invariant checks, you can use `assert_or_internal_err!`, +//! `assert_eq_or_internal_err!`, `assert_ne_or_internal_err!` for easier assertions. +//! On the performance-critical path, use `debug_assert!` instead to reduce overhead. 
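+//!
+//! As an illustrative sketch (the function below is hypothetical, not part of
+//! this crate), the two kinds of errors are typically raised like this:
+//!
+//! ```text
+//! fn lookup(values: &[i64], i: usize) -> Result<i64> {
+//!     // Expected error: the caller asked for an invalid position
+//!     if i >= values.len() {
+//!         return exec_err!("index {i} out of bounds for {} values", values.len());
+//!     }
+//!     // Unexpected error: guard an internal invariant
+//!     assert_or_internal_err!(!values.is_empty(), "values must be non-empty here");
+//!     Ok(values[i])
+//! }
+//! ```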
+
 #[cfg(feature = "backtrace")]
 use std::backtrace::{Backtrace, BacktraceStatus};
@@ -30,8 +48,6 @@
 use std::sync::Arc;
 
 use crate::utils::datafusion_strsim::normalized_levenshtein;
 use crate::utils::quote_identifier;
 use crate::{Column, DFSchema, Diagnostic, TableReference};
-#[cfg(feature = "avro")]
-use apache_avro::Error as AvroError;
 use arrow::error::ArrowError;
 #[cfg(feature = "parquet")]
 use parquet::errors::ParquetError;
@@ -58,9 +74,6 @@ pub enum DataFusionError {
     /// Error when reading / writing Parquet data.
     #[cfg(feature = "parquet")]
     ParquetError(Box<ParquetError>),
-    /// Error when reading Avro data.
-    #[cfg(feature = "avro")]
-    AvroError(Box<AvroError>),
     /// Error when reading / writing to / from an object_store (e.g. S3 or LocalFile)
     #[cfg(feature = "object_store")]
     ObjectStore(Box<object_store::Error>),
@@ -153,6 +166,10 @@ pub enum DataFusionError {
     /// to multiple receivers. For example, when the source of a repartition
     /// errors and the error is propagated to multiple consumers.
     Shared(Arc<DataFusionError>),
+    /// An error that originated during a foreign function interface call.
+    /// Transferring errors across the FFI boundary is difficult, so the original
+    /// error will be converted to a string.
+    Ffi(String),
 }
 
 #[macro_export]
@@ -310,13 +327,6 @@ impl From<ParquetError> for DataFusionError {
     }
 }
 
-#[cfg(feature = "avro")]
-impl From<AvroError> for DataFusionError {
-    fn from(e: AvroError) -> Self {
-        DataFusionError::AvroError(Box::new(e))
-    }
-}
-
 #[cfg(feature = "object_store")]
 impl From<object_store::Error> for DataFusionError {
     fn from(e: object_store::Error) -> Self {
@@ -367,8 +377,6 @@ impl Error for DataFusionError {
             DataFusionError::ArrowError(e, _) => Some(e.as_ref()),
             #[cfg(feature = "parquet")]
             DataFusionError::ParquetError(e) => Some(e.as_ref()),
-            #[cfg(feature = "avro")]
-            DataFusionError::AvroError(e) => Some(e.as_ref()),
             #[cfg(feature = "object_store")]
             DataFusionError::ObjectStore(e) => Some(e.as_ref()),
             DataFusionError::IoError(e) => Some(e),
@@ -395,6 +403,7 @@ impl Error for DataFusionError {
             // can't be executed.
             DataFusionError::Collection(errs) => errs.first().map(|e| e as &dyn Error),
             DataFusionError::Shared(e) => Some(e.as_ref()),
+            DataFusionError::Ffi(_) => None,
         }
     }
 }
@@ -497,8 +506,6 @@ impl DataFusionError {
             DataFusionError::ArrowError(_, _) => "Arrow error: ",
             #[cfg(feature = "parquet")]
             DataFusionError::ParquetError(_) => "Parquet error: ",
-            #[cfg(feature = "avro")]
-            DataFusionError::AvroError(_) => "Avro error: ",
             #[cfg(feature = "object_store")]
             DataFusionError::ObjectStore(_) => "Object Store error: ",
             DataFusionError::IoError(_) => "IO error: ",
@@ -526,6 +533,7 @@ impl DataFusionError {
                 errs.first().expect("cannot construct DataFusionError::Collection with 0 errors, but got one such case").error_prefix()
             }
             DataFusionError::Shared(_) => "",
+            DataFusionError::Ffi(_) => "FFI error: ",
         }
     }
 
@@ -537,8 +545,6 @@ impl DataFusionError {
             }
             #[cfg(feature = "parquet")]
             DataFusionError::ParquetError(ref desc) => Cow::Owned(desc.to_string()),
-            #[cfg(feature = "avro")]
-            DataFusionError::AvroError(ref desc) => Cow::Owned(desc.to_string()),
             DataFusionError::IoError(ref desc) => Cow::Owned(desc.to_string()),
             #[cfg(feature = "sql")]
             DataFusionError::SQL(ref desc, ref backtrace) => {
@@ -578,6 +584,7 @@ impl DataFusionError {
                 .expect("cannot construct DataFusionError::Collection with 0 errors")
                 .message(),
             DataFusionError::Shared(ref desc) => Cow::Owned(desc.to_string()),
+            DataFusionError::Ffi(ref desc) => Cow::Owned(desc.to_string()),
         }
     }
 
@@ -750,7 +757,7 @@ impl DataFusionErrorBuilder {
 macro_rules! 
unwrap_or_internal_err { ($Value: ident) => { $Value.ok_or_else(|| { - DataFusionError::Internal(format!( + $crate::DataFusionError::Internal(format!( "{} should not be None", stringify!($Value) )) @@ -758,6 +765,116 @@ macro_rules! unwrap_or_internal_err { }; } +/// Assert a condition, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_or_internal_err!(predicate); +/// assert_or_internal_err!(predicate, "human readable message"); +/// assert_or_internal_err!(predicate, format!("details: {}", value)); +/// ``` +#[macro_export] +macro_rules! assert_or_internal_err { + ($cond:expr) => { + if !$cond { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {}", + stringify!($cond) + ))); + } + }; + ($cond:expr, $($arg:tt)+) => { + if !$cond { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {}: {}", + stringify!($cond), + format!($($arg)+) + ))); + } + }; +} + +/// Assert equality, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_eq_or_internal_err!(actual, expected); +/// assert_eq_or_internal_err!(left_expr, right_expr, "values must match"); +/// assert_eq_or_internal_err!(lhs, rhs, "metadata: {}", extra); +/// ``` +#[macro_export] +macro_rules! assert_eq_or_internal_err { + ($left:expr, $right:expr $(,)?) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val != right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} == {} (left: {:?}, right: {:?})", + stringify!($left), + stringify!($right), + left_val, + right_val + ))); + } + }}; + ($left:expr, $right:expr, $($arg:tt)+) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val != right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} == {} (left: {:?}, right: {:?}): {}", + stringify!($left), + stringify!($right), + left_val, + right_val, + format!($($arg)+) + ))); + } + }}; +} + +/// Assert inequality, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_ne_or_internal_err!(left, right); +/// assert_ne_or_internal_err!(lhs_expr, rhs_expr, "values must differ"); +/// assert_ne_or_internal_err!(a, b, "context {}", info); +/// ``` +#[macro_export] +macro_rules! assert_ne_or_internal_err { + ($left:expr, $right:expr $(,)?) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val == right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} != {} (left: {:?}, right: {:?})", + stringify!($left), + stringify!($right), + left_val, + right_val + ))); + } + }}; + ($left:expr, $right:expr, $($arg:tt)+) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val == right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} != {} (left: {:?}, right: {:?}): {}", + stringify!($left), + stringify!($right), + left_val, + right_val, + format!($($arg)+) + ))); + } + }}; +} + /// Add a macros for concise DataFusionError::* errors declaration /// supports placeholders the same way as `format!` /// Examples: @@ -768,84 +885,131 @@ macro_rules! 
unwrap_or_internal_err { /// plan_err!("Error {val:?}") /// /// `NAME_ERR` - macro name for wrapping Err(DataFusionError::*) +/// `PREFIXED_NAME_ERR` - underscore-prefixed alias for NAME_ERR (e.g., _plan_err) +/// (Needed to avoid compiler error when using macro in the same crate: `macros from the current crate cannot be referred to by absolute paths`) /// `NAME_DF_ERR` - macro name for wrapping DataFusionError::*. Needed to keep backtrace opportunity /// in construction where DataFusionError::* used directly, like `map_err`, `ok_or_else`, etc +/// `PREFIXED_NAME_DF_ERR` - underscore-prefixed alias for NAME_DF_ERR (e.g., _plan_datafusion_err). +/// (Needed to avoid compiler error when using macro in the same crate: `macros from the current crate cannot be referred to by absolute paths`) macro_rules! make_error { - ($NAME_ERR:ident, $NAME_DF_ERR: ident, $ERR:ident) => { make_error!(@inner ($), $NAME_ERR, $NAME_DF_ERR, $ERR); }; - (@inner ($d:tt), $NAME_ERR:ident, $NAME_DF_ERR:ident, $ERR:ident) => { - ::paste::paste!{ - /// Macro wraps `$ERR` to add backtrace feature - #[macro_export] - macro_rules! $NAME_DF_ERR { - ($d($d args:expr),* $d(; diagnostic=$d DIAG:expr)?) => {{ - let err =$crate::DataFusionError::$ERR( - ::std::format!( - "{}{}", - ::std::format!($d($d args),*), - $crate::DataFusionError::get_back_trace(), - ).into() - ); - $d ( - let err = err.with_diagnostic($d DIAG); - )? - err - } - } + ($NAME_ERR:ident, $PREFIXED_NAME_ERR:ident, $NAME_DF_ERR:ident, $PREFIXED_NAME_DF_ERR:ident, $ERR:ident) => { + make_error!(@inner ($), $NAME_ERR, $PREFIXED_NAME_ERR, $NAME_DF_ERR, $PREFIXED_NAME_DF_ERR, $ERR); + }; + (@inner ($d:tt), $NAME_ERR:ident, $PREFIXED_NAME_ERR:ident, $NAME_DF_ERR:ident, $PREFIXED_NAME_DF_ERR:ident, $ERR:ident) => { + /// Macro wraps `$ERR` to add backtrace feature + #[macro_export] + macro_rules! $NAME_DF_ERR { + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ + let err = $crate::DataFusionError::$ERR( + ::std::format!( + "{}{}", + ::std::format!($d($d args),*), + $crate::DataFusionError::get_back_trace(), + ).into() + ); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + err + }} } - /// Macro wraps Err(`$ERR`) to add backtrace feature - #[macro_export] - macro_rules! $NAME_ERR { - ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ - let err = $crate::[<_ $NAME_DF_ERR>]!($d($d args),*); - $d ( - let err = err.with_diagnostic($d DIAG); - )? - Err(err) - - }} - } - - - // Note: Certain macros are used in this crate, but not all. - // This macro generates a use or all of them in case they are needed - // so we allow unused code to avoid warnings when they are not used - #[doc(hidden)] - #[allow(unused)] - pub use $NAME_ERR as [<_ $NAME_ERR>]; - #[doc(hidden)] - #[allow(unused)] - pub use $NAME_DF_ERR as [<_ $NAME_DF_ERR>]; + /// Macro wraps Err(`$ERR`) to add backtrace feature + #[macro_export] + macro_rules! $NAME_ERR { + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ + let err = $crate::$PREFIXED_NAME_DF_ERR!($d($d args),*); + $d ( + let err = err.with_diagnostic($d DIAG); + )? 
+ Err(err) + }} } + + #[doc(hidden)] + pub use $NAME_ERR as $PREFIXED_NAME_ERR; + #[doc(hidden)] + pub use $NAME_DF_ERR as $PREFIXED_NAME_DF_ERR; }; } // Exposes a macro to create `DataFusionError::Plan` with optional backtrace -make_error!(plan_err, plan_datafusion_err, Plan); +make_error!( + plan_err, + _plan_err, + plan_datafusion_err, + _plan_datafusion_err, + Plan +); // Exposes a macro to create `DataFusionError::Internal` with optional backtrace -make_error!(internal_err, internal_datafusion_err, Internal); +make_error!( + internal_err, + _internal_err, + internal_datafusion_err, + _internal_datafusion_err, + Internal +); // Exposes a macro to create `DataFusionError::NotImplemented` with optional backtrace -make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); +make_error!( + not_impl_err, + _not_impl_err, + not_impl_datafusion_err, + _not_impl_datafusion_err, + NotImplemented +); // Exposes a macro to create `DataFusionError::Execution` with optional backtrace -make_error!(exec_err, exec_datafusion_err, Execution); +make_error!( + exec_err, + _exec_err, + exec_datafusion_err, + _exec_datafusion_err, + Execution +); // Exposes a macro to create `DataFusionError::Configuration` with optional backtrace -make_error!(config_err, config_datafusion_err, Configuration); +make_error!( + config_err, + _config_err, + config_datafusion_err, + _config_datafusion_err, + Configuration +); // Exposes a macro to create `DataFusionError::Substrait` with optional backtrace -make_error!(substrait_err, substrait_datafusion_err, Substrait); +make_error!( + substrait_err, + _substrait_err, + substrait_datafusion_err, + _substrait_datafusion_err, + Substrait +); // Exposes a macro to create `DataFusionError::ResourcesExhausted` with optional backtrace -make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); +make_error!( + resources_err, + _resources_err, + resources_datafusion_err, + _resources_datafusion_err, + ResourcesExhausted +); + +// Exposes a macro to create `DataFusionError::Ffi` with optional backtrace +make_error!( + ffi_err, + _ffi_err, + ffi_datafusion_err, + _ffi_datafusion_err, + Ffi +); // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = DataFusionError::SQL(Box::new($ERR), Some(DataFusionError::get_back_trace())); + let err = $crate::DataFusionError::SQL(Box::new($ERR), Some($crate::DataFusionError::get_back_trace())); $( let err = err.with_diagnostic($DIAG); )? @@ -857,7 +1021,7 @@ macro_rules! sql_datafusion_err { #[macro_export] macro_rules! sql_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = datafusion_common::sql_datafusion_err!($ERR); + let err = $crate::sql_datafusion_err!($ERR); $( let err = err.with_diagnostic($DIAG); )? @@ -869,7 +1033,7 @@ macro_rules! sql_err { #[macro_export] macro_rules! arrow_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = DataFusionError::ArrowError(Box::new($ERR), Some(DataFusionError::get_back_trace())); + let err = $crate::DataFusionError::ArrowError(Box::new($ERR), Some($crate::DataFusionError::get_back_trace())); $( let err = err.with_diagnostic($DIAG); )? @@ -882,7 +1046,7 @@ macro_rules! arrow_datafusion_err { macro_rules! arrow_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => { { - let err = datafusion_common::arrow_datafusion_err!($ERR); + let err = $crate::arrow_datafusion_err!($ERR); $( let err = err.with_diagnostic($DIAG); )? 
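To make the `$crate::` rewrites above concrete: a minimal sketch (the helper function is illustrative, not part of this patch) of invoking one of these macros from a downstream crate. The expansion now resolves through `$crate::`, so it no longer assumes `datafusion_common` is in scope under that exact name:

```rust
use datafusion_common::{plan_err, Result};

// Hypothetical validation helper: `plan_err!` expands to
// `$crate::_plan_datafusion_err!(...)`, so the same invocation works
// both inside datafusion_common and from external crates.
fn require_table_name(name: &str) -> Result<()> {
    if name.is_empty() {
        return plan_err!("table name must not be empty");
    }
    Ok(())
}
```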
@@ -894,9 +1058,9 @@ macro_rules! arrow_err { #[macro_export] macro_rules! schema_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = $crate::error::DataFusionError::SchemaError( + let err = $crate::DataFusionError::SchemaError( Box::new($ERR), - Box::new(Some($crate::error::DataFusionError::get_back_trace())), + Box::new(Some($crate::DataFusionError::get_back_trace())), ); $( let err = err.with_diagnostic($DIAG); @@ -909,9 +1073,9 @@ macro_rules! schema_datafusion_err { #[macro_export] macro_rules! schema_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = $crate::error::DataFusionError::SchemaError( + let err = $crate::DataFusionError::SchemaError( Box::new($ERR), - Box::new(Some($crate::error::DataFusionError::get_back_trace())), + Box::new(Some($crate::DataFusionError::get_back_trace())), ); $( let err = err.with_diagnostic($DIAG); @@ -974,6 +1138,115 @@ mod test { use std::sync::Arc; use arrow::error::ArrowError; + use insta::assert_snapshot; + + fn ok_result() -> Result<()> { + Ok(()) + } + + #[test] + fn test_assert_eq_or_internal_err_passes() -> Result<()> { + assert_eq_or_internal_err!(1, 1); + ok_result() + } + + #[test] + fn test_assert_eq_or_internal_err_fails() { + fn check() -> Result<()> { + assert_eq_or_internal_err!(1, 2, "expected equality"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: 1 == 2 (left: 1, right: 2): expected equality. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_ne_or_internal_err_passes() -> Result<()> { + assert_ne_or_internal_err!(1, 2); + ok_result() + } + + #[test] + fn test_assert_ne_or_internal_err_fails() { + fn check() -> Result<()> { + assert_ne_or_internal_err!(3, 3, "values must differ"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: 3 != 3 (left: 3, right: 3): values must differ. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_passes() -> Result<()> { + assert_or_internal_err!(true); + assert_or_internal_err!(true, "message"); + ok_result() + } + + #[test] + fn test_assert_or_internal_err_fails_default() { + fn check() -> Result<()> { + assert_or_internal_err!(false); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_fails_with_message() { + fn check() -> Result<()> { + assert_or_internal_err!(false, "custom message"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false: custom message. + This issue was likely caused by a bug in DataFusion's code. 
Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_with_format_arguments() { + fn check() -> Result<()> { + assert_or_internal_err!(false, "custom {}", 42); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false: custom 42. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } #[test] fn test_error_size() { @@ -986,9 +1259,10 @@ mod test { #[test] fn datafusion_error_to_arrow() { let res = return_arrow_error().unwrap_err(); - assert!(res - .to_string() - .starts_with("External error: Error during planning: foo")); + assert!( + res.to_string() + .starts_with("External error: Error during planning: foo") + ); } #[test] @@ -1000,7 +1274,6 @@ mod test { // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { match std::env::var("RUST_BACKTRACE") { Ok(val) if val == "1" => {} @@ -1017,17 +1290,17 @@ mod test { .unwrap(), &"Error during planning: Err" ); - assert!(!err - .split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .is_empty()); + assert!( + !err.split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty() + ); } #[cfg(not(feature = "backtrace"))] #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_disabled_backtrace() { let res: Result<(), DataFusionError> = plan_err!("Err"); let res = res.unwrap_err().to_string(); @@ -1097,7 +1370,6 @@ mod test { } #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_make_error_parse_input() { let res: Result<(), DataFusionError> = plan_err!("Err"); let res = res.unwrap_err(); @@ -1166,9 +1438,11 @@ mod test { let external_error_2: DataFusionError = generic_error_2.into(); println!("{external_error_2}"); - assert!(external_error_2 - .to_string() - .starts_with("External error: io error")); + assert!( + external_error_2 + .to_string() + .starts_with("External error: io error") + ); } /// Model what happens when implementing SendableRecordBatchStream: diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 943288af91642..fa116d17277cc 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -31,6 +31,8 @@ pub struct CsvWriterOptions { /// Compression to apply after ArrowWriter serializes RecordBatches. /// This compression is applied by DataFusion not the ArrowWriter itself. pub compression: CompressionTypeVariant, + /// Compression level for the output file. + pub compression_level: Option, } impl CsvWriterOptions { @@ -41,6 +43,20 @@ impl CsvWriterOptions { Self { writer_options, compression, + compression_level: None, + } + } + + /// Create a new `CsvWriterOptions` with the specified compression level. 
+ pub fn new_with_level( + writer_options: WriterBuilder, + compression: CompressionTypeVariant, + compression_level: u32, + ) -> Self { + Self { + writer_options, + compression, + compression_level: Some(compression_level), } } } @@ -78,9 +94,17 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { if let Some(v) = &value.double_quote { builder = builder.with_double_quote(*v) } + builder = builder.with_quote_style(value.quote_style.into()); + if let Some(v) = &value.ignore_leading_whitespace { + builder = builder.with_ignore_leading_whitespace(*v) + } + if let Some(v) = &value.ignore_trailing_whitespace { + builder = builder.with_ignore_trailing_whitespace(*v) + } Ok(CsvWriterOptions { writer_options: builder, compression: value.compression, + compression_level: value.compression_level, }) } } diff --git a/datafusion/common/src/file_options/json_writer.rs b/datafusion/common/src/file_options/json_writer.rs index 750d2972329bb..a537192c8128a 100644 --- a/datafusion/common/src/file_options/json_writer.rs +++ b/datafusion/common/src/file_options/json_writer.rs @@ -27,11 +27,26 @@ use crate::{ #[derive(Clone, Debug)] pub struct JsonWriterOptions { pub compression: CompressionTypeVariant, + pub compression_level: Option, } impl JsonWriterOptions { pub fn new(compression: CompressionTypeVariant) -> Self { - Self { compression } + Self { + compression, + compression_level: None, + } + } + + /// Create a new `JsonWriterOptions` with the specified compression and level. + pub fn new_with_level( + compression: CompressionTypeVariant, + compression_level: u32, + ) -> Self { + Self { + compression, + compression_level: Some(compression_level), + } } } @@ -41,6 +56,7 @@ impl TryFrom<&JsonOptions> for JsonWriterOptions { fn try_from(value: &JsonOptions) -> Result { Ok(JsonWriterOptions { compression: value.compression, + compression_level: value.compression_level, }) } } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 02667e0165717..5d2abd23172ed 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -31,10 +31,10 @@ mod tests { use std::collections::HashMap; use crate::{ + Result, config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - Result, }; use parquet::{ @@ -84,7 +84,7 @@ mod tests { .build(); // Verify the expected options propagated down to parquet crate WriterProperties struct - assert_eq!(properties.max_row_group_size(), 123); + assert_eq!(properties.max_row_group_row_count(), Some(123)); assert_eq!(properties.data_page_size_limit(), 123); assert_eq!(properties.write_batch_size(), 123); assert_eq!(properties.writer_version(), WriterVersion::PARQUET_2_0); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 564929c61bab0..eaf5a1642e8e2 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -20,22 +20,20 @@ use std::sync::Arc; use crate::{ + _internal_datafusion_err, DataFusionError, Result, config::{ParquetOptions, TableParquetOptions}, - DataFusionError, Result, _internal_datafusion_err, }; use arrow::datatypes::Schema; use parquet::arrow::encode_arrow_schema; -// TODO: handle once deprecated -#[allow(deprecated)] use parquet::{ arrow::ARROW_SCHEMA_META_KEY, basic::{BrotliLevel, GzipLevel, ZstdLevel}, file::{ metadata::KeyValue, properties::{ - 
EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, - DEFAULT_STATISTICS_ENABLED, + DEFAULT_STATISTICS_ENABLED, EnabledStatistics, WriterProperties, + WriterPropertiesBuilder, }, }, schema::types::ColumnPath, @@ -97,7 +95,7 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { global, column_specific_options, key_value_metadata, - crypto: _, + .. } = table_parquet_options; let mut builder = global.into_writer_properties_builder()?; @@ -106,7 +104,9 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { if !global.skip_arrow_metadata && !key_value_metadata.contains_key(ARROW_SCHEMA_META_KEY) { - return Err(_internal_datafusion_err!("arrow schema was not added to the kv_metadata, even though it is required by configuration settings")); + return Err(_internal_datafusion_err!( + "arrow schema was not added to the kv_metadata, even though it is required by configuration settings" + )); } // add kv_meta, if any @@ -174,7 +174,6 @@ impl ParquetOptions { /// /// Note that this method does not include the key_value_metadata from [`TableParquetOptions`]. pub fn into_writer_properties_builder(&self) -> Result { - #[allow(deprecated)] let ParquetOptions { data_pagesize_limit, write_batch_size, @@ -192,6 +191,7 @@ impl ParquetOptions { bloom_filter_on_write, bloom_filter_fpp, bloom_filter_ndv, + use_content_defined_chunking, // not in WriterProperties enable_page_index: _, @@ -200,6 +200,7 @@ impl ParquetOptions { metadata_size_hint: _, pushdown_filters: _, reorder_filters: _, + force_filter_selections: _, // not used for writer props allow_single_file_parallelism: _, maximum_parallel_row_group_writers: _, maximum_buffered_record_batches_per_stream: _, @@ -214,7 +215,7 @@ impl ParquetOptions { let mut builder = WriterProperties::builder() .set_data_page_size_limit(*data_pagesize_limit) .set_write_batch_size(*write_batch_size) - .set_writer_version(parse_version_string(writer_version.as_str())?) 
+ .set_writer_version((*writer_version).into()) .set_dictionary_page_size_limit(*dictionary_page_size_limit) .set_statistics_enabled( statistics_enabled @@ -222,7 +223,7 @@ impl ParquetOptions { .and_then(|s| parse_statistics_string(s).ok()) .unwrap_or(DEFAULT_STATISTICS_ENABLED), ) - .set_max_row_group_size(*max_row_group_size) + .set_max_row_group_row_count(Some(*max_row_group_size)) .set_created_by(created_by.clone()) .set_column_index_truncate_length(*column_index_truncate_length) .set_statistics_truncate_length(*statistics_truncate_length) @@ -247,6 +248,26 @@ impl ParquetOptions { if let Some(encoding) = encoding { builder = builder.set_encoding(parse_encoding_string(encoding)?); } + if let Some(cdc) = use_content_defined_chunking { + if cdc.min_chunk_size == 0 { + return Err(DataFusionError::Configuration( + "CDC min_chunk_size must be greater than 0".to_string(), + )); + } + if cdc.max_chunk_size <= cdc.min_chunk_size { + return Err(DataFusionError::Configuration(format!( + "CDC max_chunk_size ({}) must be greater than min_chunk_size ({})", + cdc.max_chunk_size, cdc.min_chunk_size + ))); + } + builder = builder.set_content_defined_chunking(Some( + parquet::file::properties::CdcOptions { + min_chunk_size: cdc.min_chunk_size, + max_chunk_size: cdc.max_chunk_size, + norm_level: cdc.norm_level, + }, + )); + } Ok(builder) } @@ -261,7 +282,7 @@ pub(crate) fn parse_encoding_string( "plain" => Ok(parquet::basic::Encoding::PLAIN), "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), "rle" => Ok(parquet::basic::Encoding::RLE), - #[allow(deprecated)] + #[expect(deprecated)] "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), "delta_length_byte_array" => { @@ -341,10 +362,6 @@ pub fn parse_compression_string( level, )?)) } - "lzo" => { - check_level_is_none(codec, &level)?; - Ok(parquet::basic::Compression::LZO) - } "brotli" => { let level = require_level(codec, level)?; Ok(parquet::basic::Compression::BROTLI(BrotliLevel::try_new( @@ -368,19 +385,7 @@ pub fn parse_compression_string( _ => Err(DataFusionError::Configuration(format!( "Unknown or unsupported parquet compression: \ {str_setting}. Valid values are: uncompressed, snappy, gzip(level), \ - lzo, brotli(level), lz4, zstd(level), and lz4_raw." - ))), - } -} - -pub(crate) fn parse_version_string(str_setting: &str) -> Result { - let str_setting_lower: &str = &str_setting.to_lowercase(); - match str_setting_lower { - "1.0" => Ok(WriterVersion::PARQUET_1_0), - "2.0" => Ok(WriterVersion::PARQUET_2_0), - _ => Err(DataFusionError::Configuration(format!( - "Unknown or unsupported parquet writer version {str_setting} \ - valid options are 1.0 and 2.0" + brotli(level), lz4, zstd(level), and lz4_raw." 
))), } } @@ -402,14 +407,16 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result ParquetOptions { let defaults = ParquetOptions::default(); - let writer_version = if defaults.writer_version.eq("1.0") { - "2.0" + let writer_version = if defaults.writer_version.eq(&DFParquetWriterVersion::V1_0) + { + DFParquetWriterVersion::V2_0 } else { - "1.0" + DFParquetWriterVersion::V1_0 }; - #[allow(deprecated)] // max_statistics_size ParquetOptions { data_pagesize_limit: 42, write_batch_size: 42, - writer_version: writer_version.into(), + writer_version, compression: Some("zstd(22)".into()), dictionary_enabled: Some(!defaults.dictionary_enabled.unwrap_or(false)), dictionary_page_size_limit: 42, @@ -464,6 +471,7 @@ mod tests { metadata_size_hint: defaults.metadata_size_hint, pushdown_filters: defaults.pushdown_filters, reorder_filters: defaults.reorder_filters, + force_filter_selections: defaults.force_filter_selections, allow_single_file_parallelism: defaults.allow_single_file_parallelism, maximum_parallel_row_group_writers: defaults .maximum_parallel_row_group_writers, @@ -475,6 +483,7 @@ mod tests { skip_arrow_metadata: defaults.skip_arrow_metadata, coerce_int96: None, max_predicate_cache_size: defaults.max_predicate_cache_size, + use_content_defined_chunking: defaults.use_content_defined_chunking.clone(), } } @@ -484,7 +493,6 @@ mod tests { ) -> ParquetColumnOptions { let bloom_filter_default_props = props.bloom_filter_properties(&col); - #[allow(deprecated)] // max_statistics_size ParquetColumnOptions { bloom_filter_enabled: Some(bloom_filter_default_props.is_some()), encoding: props.encoding(&col).map(|s| s.to_string()), @@ -545,15 +553,16 @@ mod tests { #[cfg(not(feature = "parquet_encryption"))] let fep = None; - #[allow(deprecated)] // max_statistics_size TableParquetOptions { global: ParquetOptions { // global options data_pagesize_limit: props.dictionary_page_size_limit(), write_batch_size: props.write_batch_size(), - writer_version: format!("{}.0", props.writer_version().as_num()), + writer_version: props.writer_version().into(), dictionary_page_size_limit: props.dictionary_page_size_limit(), - max_row_group_size: props.max_row_group_size(), + max_row_group_size: props + .max_row_group_row_count() + .unwrap_or(DEFAULT_MAX_ROW_GROUP_ROW_COUNT), created_by: props.created_by().to_string(), column_index_truncate_length: props.column_index_truncate_length(), statistics_truncate_length: props.statistics_truncate_length(), @@ -577,6 +586,7 @@ mod tests { metadata_size_hint: global_options_defaults.metadata_size_hint, pushdown_filters: global_options_defaults.pushdown_filters, reorder_filters: global_options_defaults.reorder_filters, + force_filter_selections: global_options_defaults.force_filter_selections, allow_single_file_parallelism: global_options_defaults .allow_single_file_parallelism, maximum_parallel_row_group_writers: global_options_defaults @@ -590,6 +600,13 @@ mod tests { binary_as_string: global_options_defaults.binary_as_string, skip_arrow_metadata: global_options_defaults.skip_arrow_metadata, coerce_int96: None, + use_content_defined_chunking: props.content_defined_chunking().map(|c| { + CdcOptions { + min_chunk_size: c.min_chunk_size, + max_chunk_size: c.max_chunk_size, + norm_level: c.norm_level, + } + }), }, column_specific_options, key_value_metadata, @@ -674,8 +691,7 @@ mod tests { let mut default_table_writer_opts = TableParquetOptions::default(); let default_parquet_opts = ParquetOptions::default(); assert_eq!( - default_table_writer_opts.global, - 
default_parquet_opts, + default_table_writer_opts.global, default_parquet_opts, "should have matching defaults for TableParquetOptions.global and ParquetOptions", ); @@ -699,7 +715,9 @@ mod tests { "should have different created_by sources", ); assert!( - default_writer_props.created_by().starts_with("parquet-rs version"), + default_writer_props + .created_by() + .starts_with("parquet-rs version"), "should indicate that writer_props defaults came from the extern parquet crate", ); assert!( @@ -733,8 +751,7 @@ mod tests { from_extern_parquet.global.skip_arrow_metadata = true; assert_eq!( - default_table_writer_opts, - from_extern_parquet, + default_table_writer_opts, from_extern_parquet, "the default writer_props should have the same configuration as the session's default TableParquetOptions", ); } @@ -800,6 +817,74 @@ mod tests { ); } + #[test] + fn test_cdc_enabled_with_custom_options() { + let mut opts = TableParquetOptions::default(); + opts.global.use_content_defined_chunking = Some(CdcOptions { + min_chunk_size: 128 * 1024, + max_chunk_size: 512 * 1024, + norm_level: 2, + }); + opts.arrow_schema(&Arc::new(Schema::empty())); + + let props = WriterPropertiesBuilder::try_from(&opts).unwrap().build(); + let cdc = props.content_defined_chunking().expect("CDC should be set"); + assert_eq!(cdc.min_chunk_size, 128 * 1024); + assert_eq!(cdc.max_chunk_size, 512 * 1024); + assert_eq!(cdc.norm_level, 2); + } + + #[test] + fn test_cdc_disabled_by_default() { + let mut opts = TableParquetOptions::default(); + opts.arrow_schema(&Arc::new(Schema::empty())); + + let props = WriterPropertiesBuilder::try_from(&opts).unwrap().build(); + assert!(props.content_defined_chunking().is_none()); + } + + #[test] + fn test_cdc_round_trip_through_writer_props() { + let mut opts = TableParquetOptions::default(); + opts.global.use_content_defined_chunking = Some(CdcOptions { + min_chunk_size: 64 * 1024, + max_chunk_size: 2 * 1024 * 1024, + norm_level: -1, + }); + opts.arrow_schema(&Arc::new(Schema::empty())); + + let props = WriterPropertiesBuilder::try_from(&opts).unwrap().build(); + let recovered = session_config_from_writer_props(&props); + + let cdc = recovered.global.use_content_defined_chunking.unwrap(); + assert_eq!(cdc.min_chunk_size, 64 * 1024); + assert_eq!(cdc.max_chunk_size, 2 * 1024 * 1024); + assert_eq!(cdc.norm_level, -1); + } + + #[test] + fn test_cdc_validation_zero_min_chunk_size() { + let mut opts = TableParquetOptions::default(); + opts.global.use_content_defined_chunking = Some(CdcOptions { + min_chunk_size: 0, + ..CdcOptions::default() + }); + opts.arrow_schema(&Arc::new(Schema::empty())); + assert!(WriterPropertiesBuilder::try_from(&opts).is_err()); + } + + #[test] + fn test_cdc_validation_max_not_greater_than_min() { + let mut opts = TableParquetOptions::default(); + opts.global.use_content_defined_chunking = Some(CdcOptions { + min_chunk_size: 512 * 1024, + max_chunk_size: 256 * 1024, + ..CdcOptions::default() + }); + opts.arrow_schema(&Arc::new(Schema::empty())); + assert!(WriterPropertiesBuilder::try_from(&opts).is_err()); + } + #[test] fn test_bloom_filter_set_ndv_only() { // the TableParquetOptions::default, with only ndv set diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index 764190e1189bf..a6bd42be691a9 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -176,9 +176,9 @@ impl FromStr for ExplainFormat { "tree" => Ok(ExplainFormat::Tree), "pgjson" => Ok(ExplainFormat::PostgresJSON), "graphviz" => 
Ok(ExplainFormat::Graphviz),
-            _ => {
-                Err(DataFusionError::Configuration(format!("Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. Got '{format}'")))
-            }
+            _ => Err(DataFusionError::Configuration(format!(
+                "Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. Got '{format}'"
+            ))),
         }
     }
 }
@@ -206,23 +206,50 @@ impl ConfigField for ExplainFormat {
     }
 }
 
-/// Verbosity levels controlling how `EXPLAIN ANALYZE` renders metrics
+/// Categorizes metrics so the display layer can choose the desired verbosity.
+///
+/// The `datafusion.explain.analyze_level` configuration controls which
+/// metric types are shown:
+/// - `"dev"` (the default): all metrics are shown.
+/// - `"summary"`: only metrics tagged as `Summary` are shown.
+///
+/// This is orthogonal to [`MetricCategory`], which filters by *what kind*
+/// of value a metric represents (rows / bytes / timing).
+///
+/// # Difference from `EXPLAIN ANALYZE VERBOSE`
+///
+/// The `VERBOSE` keyword controls whether per-partition metrics are shown
+/// (when specified) or aggregated metrics are displayed (when omitted).
+/// In contrast, `MetricType` determines which *levels* of metrics are
+/// displayed.
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub enum ExplainAnalyzeLevel {
-    /// Show a compact view containing high-level metrics
+pub enum MetricType {
+    /// Common metrics for high-level insights (answering which operator is slow)
     Summary,
-    /// Show a developer-focused view with per-operator details
+    /// For deep operator-level introspection for developers
     Dev,
-    // When adding new enum, update the error message in `from_str()` accordingly.
 }
 
-impl FromStr for ExplainAnalyzeLevel {
+impl MetricType {
+    /// Returns the set of metric types that should be shown for this level.
+    ///
+    /// `Dev` is a superset of `Summary`: when the user selects
+    /// `analyze_level = 'dev'`, both `Summary` and `Dev` metrics are shown.
+    pub fn included_types(self) -> Vec<MetricType> {
+        match self {
+            MetricType::Summary => vec![MetricType::Summary],
+            MetricType::Dev => vec![MetricType::Summary, MetricType::Dev],
+        }
+    }
+}
+
+impl FromStr for MetricType {
     type Err = DataFusionError;
 
-    fn from_str(level: &str) -> Result<Self, Self::Err> {
-        match level.to_lowercase().as_str() {
-            "summary" => Ok(ExplainAnalyzeLevel::Summary),
-            "dev" => Ok(ExplainAnalyzeLevel::Dev),
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.trim().to_lowercase().as_str() {
+            "summary" => Ok(Self::Summary),
+            "dev" => Ok(Self::Dev),
             other => Err(DataFusionError::Configuration(format!(
                 "Invalid explain analyze level. Expected 'summary' or 'dev'. Got '{other}'"
             ))),
@@ -230,23 +257,176 @@ impl FromStr for ExplainAnalyzeLevel {
     }
 }
 
-impl Display for ExplainAnalyzeLevel {
+impl Display for MetricType {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        let s = match self {
-            ExplainAnalyzeLevel::Summary => "summary",
-            ExplainAnalyzeLevel::Dev => "dev",
-        };
-        write!(f, "{s}")
+        match self {
+            Self::Summary => write!(f, "summary"),
+            Self::Dev => write!(f, "dev"),
+        }
+    }
+}
+
+impl ConfigField for MetricType {
+    fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
+        v.some(key, self, description)
+    }
+
+    fn set(&mut self, _: &str, value: &str) -> Result<()> {
+        *self = MetricType::from_str(value)?;
+        Ok(())
+    }
+}
+
+/// Classifies a metric by what it measures.
+///
+/// This is orthogonal to [`MetricType`] (Summary / Dev), which controls
+/// *verbosity*. 
`MetricCategory` controls *what kind of value* is shown,
+/// so that `EXPLAIN ANALYZE` output can be narrowed to only the categories
+/// that are useful in a given context.
+///
+/// In particular, this is useful for testing, since metrics differ in their stability across runs:
+/// - [`Rows`](Self::Rows) and [`Bytes`](Self::Bytes) depend only on the plan
+/// and the data, so they are mostly deterministic across runs (given the same
+/// input). Variations can still exist, e.g. because of non-deterministic ordering
+/// of evaluation between threads.
+/// Running with a single target partition often makes these metrics stable enough to assert on in tests.
+/// - [`Timing`](Self::Timing) depends on hardware, system load, scheduling,
+/// etc., so it varies from run to run even on the same machine.
+///
+/// [`MetricCategory`] is especially useful in sqllogictest (`.slt`) files:
+/// setting `datafusion.explain.analyze_categories = 'rows'` lets a test
+/// assert on row-count metrics without sprinkling `` over every
+/// timing value.
+///
+/// Metrics that do not declare a category (the default for custom
+/// `Count` / `Gauge` metrics) are treated as
+/// [`Uncategorized`](Self::Uncategorized) for filtering purposes.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum MetricCategory {
+    /// Row counts and related dimensionless counters: `output_rows`,
+    /// `spilled_rows`, `output_batches`, pruning metrics, ratios, etc.
+    ///
+    /// Mostly deterministic given the same plan and data.
+    Rows,
+    /// Byte measurements: `output_bytes`, `spilled_bytes`,
+    /// `current_memory_usage`, `bytes_scanned`, etc.
+    ///
+    /// Mostly deterministic given the same plan and data.
+    Bytes,
+    /// Wall-clock durations and timestamps: `elapsed_compute`,
+    /// operator-defined `Time` metrics, `start_timestamp` /
+    /// `end_timestamp`, etc.
+    ///
+    /// **Non-deterministic** — varies across runs even on the same hardware.
+    Timing,
+    /// Catch-all for metrics that do not fit into [`Rows`](Self::Rows),
+    /// [`Bytes`](Self::Bytes), or [`Timing`](Self::Timing).
+    ///
+    /// Custom `Count` / `Gauge` metrics that are not explicitly assigned
+    /// a category are treated as `Uncategorized` for filtering purposes.
+    ///
+    /// This variant lets users explicitly include or exclude these
+    /// metrics, e.g.:
+    /// ```sql
+    /// SET datafusion.explain.analyze_categories = 'rows, bytes, uncategorized';
+    /// ```
+    Uncategorized,
+}
+
+impl FromStr for MetricCategory {
+    type Err = DataFusionError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.trim().to_lowercase().as_str() {
+            "rows" => Ok(Self::Rows),
+            "bytes" => Ok(Self::Bytes),
+            "timing" => Ok(Self::Timing),
+            "uncategorized" => Ok(Self::Uncategorized),
+            other => Err(DataFusionError::Configuration(format!(
+                "Invalid metric category '{other}'. \
+                 Expected 'rows', 'bytes', 'timing', or 'uncategorized'."
+            ))),
+        }
+    }
+}
+
+impl Display for MetricCategory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Rows => write!(f, "rows"),
+            Self::Bytes => write!(f, "bytes"),
+            Self::Timing => write!(f, "timing"),
+            Self::Uncategorized => write!(f, "uncategorized"),
+        }
+    }
+}
+
+/// Controls which [`MetricCategory`] values are shown in `EXPLAIN ANALYZE`.
+///
+/// Set via `SET datafusion.explain.analyze_categories = '...'`.
+///
+/// See [`MetricCategory`] for the determinism properties that motivate
+/// this filter.
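+///
+/// For example (a sketch mirroring the accepted `FromStr` inputs below):
+///
+/// ```sql
+/// -- keep only the mostly-deterministic metric categories
+/// SET datafusion.explain.analyze_categories = 'rows,bytes';
+/// -- suppress all metrics and show just the plan
+/// SET datafusion.explain.analyze_categories = 'none';
+/// ```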
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
+pub enum ExplainAnalyzeCategories {
+    /// Show all metrics regardless of category (the default).
+    #[default]
+    All,
+    /// Show only metrics whose category is in the list.
+    /// Metrics with no declared category are treated as
+    /// [`Uncategorized`](MetricCategory::Uncategorized) for filtering.
+    ///
+    /// An **empty** vec means "plan only" — suppress all metrics.
+    Only(Vec<MetricCategory>),
+}
+
+impl FromStr for ExplainAnalyzeCategories {
+    type Err = DataFusionError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let s = s.trim().to_lowercase();
+        match s.as_str() {
+            "all" => Ok(Self::All),
+            "none" => Ok(Self::Only(vec![])),
+            other => {
+                let mut cats = Vec::new();
+                for part in other.split(',') {
+                    cats.push(part.trim().parse::<MetricCategory>()?);
+                }
+                cats.dedup();
+                Ok(Self::Only(cats))
+            }
+        }
+    }
+}
+
+impl Display for ExplainAnalyzeCategories {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::All => write!(f, "all"),
+            Self::Only(cats) if cats.is_empty() => write!(f, "none"),
+            Self::Only(cats) => {
+                let mut first = true;
+                for cat in cats {
+                    if !first {
+                        write!(f, ",")?;
+                    }
+                    first = false;
+                    write!(f, "{cat}")?;
+                }
+                Ok(())
+            }
+        }
     }
 }
 
-impl ConfigField for ExplainAnalyzeLevel {
+impl ConfigField for ExplainAnalyzeCategories {
     fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
         v.some(key, self, description)
     }
 
     fn set(&mut self, _: &str, value: &str) -> Result<()> {
-        *self = ExplainAnalyzeLevel::from_str(value)?;
+        *self = ExplainAnalyzeCategories::from_str(value)?;
         Ok(())
     }
 }
diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs
index 63962998ad18b..24ca33c0c2c90 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -590,6 +590,53 @@ pub fn get_required_group_by_exprs_indices(
         .collect()
 }
 
+/// Returns indices for the minimal subset of ORDER BY expressions that are
+/// functionally equivalent to the original set of ORDER BY expressions.
+pub fn get_required_sort_exprs_indices(
+    schema: &DFSchema,
+    sort_expr_names: &[String],
+) -> Vec<usize> {
+    let dependencies = schema.functional_dependencies();
+    let field_names = schema.field_names();
+
+    let mut known_field_indices = HashSet::new();
+    let mut required_sort_expr_indices = Vec::new();
+
+    for (sort_expr_idx, sort_expr_name) in sort_expr_names.iter().enumerate() {
+        // If the sort expression doesn't correspond to a known schema field
+        // (e.g. a computed expression), we can't reason about it via functional
+        // dependencies, so conservatively keep it.
+        let Some(field_idx) = field_names
+            .iter()
+            .position(|field_name| field_name == sort_expr_name)
+        else {
+            required_sort_expr_indices.push(sort_expr_idx);
+            continue;
+        };
+
+        // A sort expression is removable if its value is functionally determined
+        // by fields that already appear earlier in the sort order: if the earlier
+        // fields are fixed, this one's value is fixed too, so it adds no ordering
+        // information.
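+        // (Hypothetical illustration: for `ORDER BY dept_id, dept_name` where
+        // `dept_id` functionally determines `dept_name`, the `dept_name` key is
+        // removable once `dept_id` has already been kept.)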
diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs
index 63962998ad18b..24ca33c0c2c90 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -590,6 +590,53 @@ pub fn get_required_group_by_exprs_indices(
         .collect()
 }
 
+/// Returns indices for the minimal subset of ORDER BY expressions that are
+/// functionally equivalent to the original set of ORDER BY expressions.
+pub fn get_required_sort_exprs_indices(
+    schema: &DFSchema,
+    sort_expr_names: &[String],
+) -> Vec<usize> {
+    let dependencies = schema.functional_dependencies();
+    let field_names = schema.field_names();
+
+    let mut known_field_indices = HashSet::new();
+    let mut required_sort_expr_indices = Vec::new();
+
+    for (sort_expr_idx, sort_expr_name) in sort_expr_names.iter().enumerate() {
+        // If the sort expression doesn't correspond to a known schema field
+        // (e.g. a computed expression), we can't reason about it via functional
+        // dependencies, so conservatively keep it.
+        let Some(field_idx) = field_names
+            .iter()
+            .position(|field_name| field_name == sort_expr_name)
+        else {
+            required_sort_expr_indices.push(sort_expr_idx);
+            continue;
+        };
+
+        // A sort expression is removable if its value is functionally determined
+        // by fields that already appear earlier in the sort order: if the earlier
+        // fields are fixed, this one's value is fixed too, so it adds no ordering
+        // information.
+        let removable = dependencies.deps.iter().any(|dependency| {
+            dependency.target_indices.contains(&field_idx)
+                && dependency
+                    .source_indices
+                    .iter()
+                    .all(|source_idx| known_field_indices.contains(source_idx))
+        });
+
+        if removable {
+            continue;
+        }
+
+        known_field_indices.insert(field_idx);
+        required_sort_expr_indices.push(sort_expr_idx);
+    }
+
+    required_sort_expr_indices
+}
+
 /// Updates entries inside the `entries` vector with their corresponding
 /// indices inside the `proj_indices` vector.
 fn update_elements_with_matching_indices(
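The pruning rule is easiest to see on a toy example. The sketch below re-implements the loop above over plain index lists (the `Dep` struct is a hypothetical stand-in for DataFusion's functional-dependence type): with `id` determining `name`, `ORDER BY id, name` collapses to `ORDER BY id`, while `ORDER BY name, id` keeps both keys.

```rust
use std::collections::HashSet;

// Hypothetical stand-in: one functional dependency (sources determine targets).
struct Dep {
    source_indices: Vec<usize>,
    target_indices: Vec<usize>,
}

/// Positions (within the ORDER BY list) of the sort keys that must be kept.
fn prune_sort_keys(sort_field_indices: &[usize], deps: &[Dep]) -> Vec<usize> {
    let mut known = HashSet::new();
    let mut kept = Vec::new();
    for (pos, &field) in sort_field_indices.iter().enumerate() {
        // Removable when some dependency determines `field` from fields that
        // already appeared earlier in the sort order.
        let removable = deps.iter().any(|d| {
            d.target_indices.contains(&field)
                && d.source_indices.iter().all(|s| known.contains(s))
        });
        if !removable {
            known.insert(field);
            kept.push(pos);
        }
    }
    kept
}

fn main() {
    // Schema [id, name]; `id` (index 0) determines `name` (index 1),
    // e.g. because `id` is a primary key.
    let deps = [Dep { source_indices: vec![0], target_indices: vec![1] }];
    // ORDER BY id, name  =>  `name` adds no ordering information.
    assert_eq!(prune_sort_keys(&[0, 1], &deps), vec![0]);
    // ORDER BY name, id  =>  both keys are needed.
    assert_eq!(prune_sort_keys(&[1, 0], &deps), vec![0, 1]);
}
```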
diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs
index 4b18351f708b7..fcc2e919b6cc2 100644
--- a/datafusion/common/src/hash_utils.rs
+++ b/datafusion/common/src/hash_utils.rs
@@ -17,25 +17,33 @@
 //! Functionality used both on logical and physical plans
 
-#[cfg(not(feature = "force_hash_collisions"))]
-use std::sync::Arc;
-
-use ahash::RandomState;
 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
 use arrow::array::*;
+#[cfg(not(feature = "force_hash_collisions"))]
+use arrow::compute::take;
 use arrow::datatypes::*;
 #[cfg(not(feature = "force_hash_collisions"))]
 use arrow::{downcast_dictionary_array, downcast_primitive_array};
+use foldhash::fast::FixedState;
+#[cfg(not(feature = "force_hash_collisions"))]
+use itertools::Itertools;
+#[cfg(not(feature = "force_hash_collisions"))]
+use std::collections::HashMap;
+use std::hash::{BuildHasher, Hash, Hasher};
+
+/// The hash random state used throughout DataFusion for hashing.
+pub type RandomState = FixedState;
 
 #[cfg(not(feature = "force_hash_collisions"))]
 use crate::cast::{
     as_binary_view_array, as_boolean_array, as_fixed_size_list_array,
-    as_generic_binary_array, as_large_list_array, as_list_array, as_map_array,
-    as_string_array, as_string_view_array, as_struct_array,
+    as_generic_binary_array, as_large_list_array, as_large_list_view_array,
+    as_list_array, as_list_view_array, as_map_array, as_string_array,
+    as_string_view_array, as_struct_array, as_union_array,
 };
 use crate::error::Result;
-#[cfg(not(feature = "force_hash_collisions"))]
-use crate::error::_internal_err;
+use crate::error::{_internal_datafusion_err, _internal_err};
+use std::cell::RefCell;
 
 // Combines two hashes into one hash
 #[inline]
@@ -44,6 +52,94 @@ pub fn combine_hashes(l: u64, r: u64) -> u64 {
     hash.wrapping_mul(37).wrapping_add(r)
 }
 
+/// Maximum size for the thread-local hash buffer before truncation (4MB = 524,288 u64 elements).
+/// The goal of this is to avoid unbounded memory growth that would appear as a memory leak.
+/// We allow temporary allocations beyond this size, but after use the buffer is truncated
+/// to this size.
+const MAX_BUFFER_SIZE: usize = 524_288;
+
+thread_local! {
+    /// Thread-local buffer for hash computations to avoid repeated allocations.
+    /// The buffer is reused across calls and truncated if it exceeds MAX_BUFFER_SIZE.
+    /// It typically grows to 8192 u64 elements (the default batch size), which
+    /// corresponds to 64KB of memory.
+    static HASH_BUFFER: RefCell<Vec<u64>> = const { RefCell::new(Vec::new()) };
+}
+
+/// Creates hashes for the given arrays using a thread-local buffer, then calls the provided callback
+/// with an immutable reference to the computed hashes.
+///
+/// This function manages a thread-local buffer to avoid repeated allocations. The buffer is
+/// automatically truncated if it exceeds `MAX_BUFFER_SIZE` after use.
+///
+/// # Arguments
+/// * `arrays` - The arrays to hash (must contain at least one array)
+/// * `random_state` - The random state for hashing
+/// * `callback` - A function that receives an immutable reference to the hash slice and returns a result
+///
+/// # Errors
+/// Returns an error if:
+/// - No arrays are provided
+/// - The function is called reentrantly (i.e., the callback invokes `with_hashes` again on the same thread)
+/// - The function is called during or after thread destruction
+///
+/// # Example
+/// ```ignore
+/// use datafusion_common::hash_utils::{with_hashes, RandomState};
+/// use arrow::array::{Int32Array, ArrayRef};
+/// use std::sync::Arc;
+///
+/// let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
+/// let random_state = RandomState::default();
+///
+/// let result = with_hashes([&array], &random_state, |hashes| {
+///     // Use the hashes here
+///     Ok(hashes.len())
+/// })?;
+/// ```
+pub fn with_hashes<I, T, F, R>(
+    arrays: I,
+    random_state: &RandomState,
+    callback: F,
+) -> Result<R>
+where
+    I: IntoIterator<Item = T>,
+    T: AsDynArray,
+    F: FnOnce(&[u64]) -> Result<R>,
+{
+    // Peek at the first array to determine buffer size without fully collecting
+    let mut iter = arrays.into_iter().peekable();
+
+    // Get the required size from the first array
+    let required_size = match iter.peek() {
+        Some(arr) => arr.as_dyn_array().len(),
+        None => return _internal_err!("with_hashes requires at least one array"),
+    };
+
+    HASH_BUFFER.try_with(|cell| {
+        let mut buffer = cell.try_borrow_mut()
+            .map_err(|_| _internal_datafusion_err!("with_hashes cannot be called reentrantly on the same thread"))?;
+
+        // Ensure buffer has sufficient length, clearing old values
+        buffer.clear();
+        buffer.resize(required_size, 0);
+
+        // Create hashes in the buffer - this consumes the iterator
+        create_hashes(iter, random_state, &mut buffer[..required_size])?;
+
+        // Execute the callback with an immutable slice
+        let result = callback(&buffer[..required_size])?;
+
+        // Cleanup: truncate if buffer grew too large
+        if buffer.capacity() > MAX_BUFFER_SIZE {
+            buffer.truncate(MAX_BUFFER_SIZE);
+            buffer.shrink_to_fit();
+        }
+
+        Ok(result)
+    }).map_err(|_| _internal_datafusion_err!("with_hashes cannot access thread-local storage during or after thread destruction"))?
+}
+
 #[cfg(not(feature = "force_hash_collisions"))]
 fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) {
     if mul_col {
@@ -60,12 +156,17 @@ fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col:
 
 pub trait HashValue {
     fn hash_one(&self, state: &RandomState) -> u64;
+
+    /// Write this value into an existing hasher (same data as `hash_one`).
+    fn hash_write(&self, hasher: &mut impl Hasher);
 }
 
 impl<T: HashValue + ?Sized> HashValue for &T {
     fn hash_one(&self, state: &RandomState) -> u64 {
         T::hash_one(self, state)
     }
+    fn hash_write(&self, hasher: &mut impl Hasher) {
+        T::hash_write(self, hasher)
+    }
 }
 
 macro_rules! hash_value {
@@ -74,10 +175,13 @@ macro_rules! hash_value {
             fn hash_one(&self, state: &RandomState) -> u64 {
                 state.hash_one(self)
             }
+            fn hash_write(&self, hasher: &mut impl Hasher) {
+                Hash::hash(self, hasher)
+            }
         })+
     };
 }
-hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64);
+hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, u128);
 hash_value!(bool, str, [u8], IntervalDayTime, IntervalMonthDayNano);
 
 macro_rules! hash_float_value {
@@ -86,14 +190,29 @@ macro_rules!
hash_float_value { fn hash_one(&self, state: &RandomState) -> u64 { state.hash_one(<$i>::from_ne_bytes(self.to_ne_bytes())) } + fn hash_write(&self, hasher: &mut impl Hasher) { + hasher.write(&self.to_ne_bytes()) + } })+ }; } hash_float_value!((half::f16, u16), (f32, u32), (f64, u64)); +/// Create a `SeedableRandomState` whose per-hasher seed incorporates `seed`. +/// This folds the previous hash into the hasher's initial state so only the +/// new value needs to pass through the hash function — same cost as `hash_one`. +#[cfg(not(feature = "force_hash_collisions"))] +#[inline] +fn seeded_state(seed: u64) -> foldhash::fast::SeedableRandomState { + foldhash::fast::SeedableRandomState::with_seed( + seed, + foldhash::SharedSeed::global_fixed(), + ) +} + /// Builds hash values of PrimitiveArray and writes them into `hashes_buffer` -/// If `rehash==true` this combines the previous hash value in the buffer -/// with the new hash using `combine_hashes` +/// If `rehash==true` this folds the existing hash into the hasher state +/// and hashes only the new value (avoiding a separate combine step). #[cfg(not(feature = "force_hash_collisions"))] fn hash_array_primitive( array: &PrimitiveArray, @@ -112,7 +231,9 @@ fn hash_array_primitive( if array.null_count() == 0 { if rehash { for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) { - *hash = combine_hashes(value.hash_one(random_state), *hash); + let mut hasher = seeded_state(*hash).build_hasher(); + value.hash_write(&mut hasher); + *hash = hasher.finish(); } } else { for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) { @@ -120,18 +241,16 @@ fn hash_array_primitive( } } } else if rehash { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = combine_hashes(value.hash_one(random_state), *hash); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + let mut hasher = seeded_state(hashes_buffer[i]).build_hasher(); + value.hash_write(&mut hasher); + hashes_buffer[i] = hasher.finish(); } } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = value.hash_one(random_state); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + hashes_buffer[i] = value.hash_one(random_state); } } } @@ -141,7 +260,7 @@ fn hash_array_primitive( /// with the new hash using `combine_hashes` #[cfg(not(feature = "force_hash_collisions"))] fn hash_array( - array: T, + array: &T, random_state: &RandomState, hashes_buffer: &mut [u64], rehash: bool, @@ -168,74 +287,257 @@ fn hash_array( } } } else if rehash { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = combine_hashes(value.hash_one(random_state), *hash); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + hashes_buffer[i] = + combine_hashes(value.hash_one(random_state), hashes_buffer[i]); } } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = value.hash_one(random_state); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + hashes_buffer[i] = value.hash_one(random_state); } } } -/// Helper 
function to update hash for a dictionary key if the value is valid +/// Hash a StringView or BytesView array +/// +/// Templated to optimize inner loop based on presence of nulls and external buffers. +/// +/// HAS_NULLS: do we have to check null in the inner loop +/// HAS_BUFFERS: if true, array has external buffers; if false, all strings are inlined/ less then 12 bytes +/// REHASH: if true, combining with existing hash, otherwise initializing #[cfg(not(feature = "force_hash_collisions"))] -#[inline] -fn update_hash_for_dict_key( - hash: &mut u64, - dict_hashes: &[u64], - dict_values: &dyn Array, - idx: usize, - multi_col: bool, +#[inline(never)] +fn hash_string_view_array_inner< + T: ByteViewType, + const HAS_NULLS: bool, + const HAS_BUFFERS: bool, + const REHASH: bool, +>( + array: &GenericByteViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], ) { - if dict_values.is_valid(idx) { - if multi_col { - *hash = combine_hashes(dict_hashes[idx], *hash); + assert_eq!( + hashes_buffer.len(), + array.len(), + "hashes_buffer and array should be of equal length" + ); + + let buffers = array.data_buffers(); + let view_bytes = |view_len: u32, view: u128| { + let view = ByteView::from(view); + let offset = view.offset as usize; + // SAFETY: view is a valid view as it came from the array + unsafe { + let data = buffers.get_unchecked(view.buffer_index as usize); + data.get_unchecked(offset..offset + view_len as usize) + } + }; + + let hashes_and_views = hashes_buffer.iter_mut().zip(array.views().iter()); + for (i, (hash, &v)) in hashes_and_views.enumerate() { + if HAS_NULLS && array.is_null(i) { + continue; + } + let view_len = v as u32; + // all views are inlined, no need to access external buffers + if !HAS_BUFFERS || view_len <= 12 { + if REHASH { + let mut hasher = seeded_state(*hash).build_hasher(); + v.hash_write(&mut hasher); + *hash = hasher.finish(); + } else { + *hash = v.hash_one(random_state); + } + continue; + } + // view is not inlined, so we need to hash the bytes as well + let value = view_bytes(view_len, v); + if REHASH { + let mut hasher = seeded_state(*hash).build_hasher(); + value.hash_write(&mut hasher); + *hash = hasher.finish(); } else { - *hash = dict_hashes[idx]; + *hash = value.hash_one(random_state); } } - // no update for invalid dictionary value } -/// Hash the values in a dictionary array +/// Builds hash values for array views and writes them into `hashes_buffer` +/// If `rehash==true` this combines the previous hash value in the buffer +/// with the new hash using `combine_hashes` #[cfg(not(feature = "force_hash_collisions"))] -fn hash_dictionary( +fn hash_generic_byte_view_array( + array: &GenericByteViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) { + // instantiate the correct version based on presence of nulls and external buffers + match ( + array.null_count() != 0, + !array.data_buffers().is_empty(), + rehash, + ) { + // no nulls or buffers ==> hash the inlined views directly + // don't call the inner function as Rust seems better able to inline this simpler code (2-3% faster) + (false, false, false) => { + for (hash, &view) in hashes_buffer.iter_mut().zip(array.views().iter()) { + *hash = view.hash_one(random_state); + } + } + (false, false, true) => { + for (hash, &view) in hashes_buffer.iter_mut().zip(array.views().iter()) { + let mut hasher = seeded_state(*hash).build_hasher(); + view.hash_write(&mut hasher); + *hash = hasher.finish(); + } + } + (false, true, false) => hash_string_view_array_inner::( + 
array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + +/// Hash dictionary array with compile-time specialization for null handling. +/// +/// Uses const generics to eliminate runtim branching in the hot loop: +/// - `HAS_NULL_KEYS`: Whether to check for null dictionary keys +/// - `HAS_NULL_VALUES`: Whether to check for null dictionary values +/// - `MULTI_COL`: Whether to combine with existing hash (true) or initialize (false) +#[cfg(not(feature = "force_hash_collisions"))] +#[inline(never)] +fn hash_dictionary_inner< + K: ArrowDictionaryKeyType, + const HAS_NULL_KEYS: bool, + const HAS_NULL_VALUES: bool, + const MULTI_COL: bool, +>( array: &DictionaryArray, random_state: &RandomState, hashes_buffer: &mut [u64], - multi_col: bool, ) -> Result<()> { // Hash each dictionary value once, and then use that computed // hash for each key value to avoid a potentially expensive // redundant hashing for large dictionary elements (e.g. strings) - let dict_values = Arc::clone(array.values()); - let mut dict_hashes = vec![0; dict_values.len()]; - create_hashes(&[dict_values], random_state, &mut dict_hashes)?; - - // combine hash for each index in values let dict_values = array.values(); - for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { - if let Some(key) = key { + let mut dict_hashes = vec![0; dict_values.len()]; + create_hashes([dict_values], random_state, &mut dict_hashes)?; + + if HAS_NULL_KEYS { + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { + if let Some(key) = key { + let idx = key.as_usize(); + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } + } + } else { + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().values()) { let idx = key.as_usize(); - update_hash_for_dict_key( - hash, - &dict_hashes, - dict_values.as_ref(), - idx, - multi_col, - ); - } // no update for Null key + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } } Ok(()) } +/// Hash the values in a dictionary array +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_dictionary( + array: &DictionaryArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + multi_col: bool, +) -> Result<()> { + let has_null_keys = array.keys().null_count() != 0; + let has_null_values = array.values().null_count() != 0; + + // Dispatcher based on null presence and multi-column mode + // Should reduce branching within hot loops + match (has_null_keys, has_null_values, multi_col) { + (false, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_dictionary_inner::( + array, + 
random_state, + hashes_buffer, + ), + (true, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_struct_array( array: &StructArray, @@ -245,19 +547,21 @@ fn hash_struct_array( let nulls = array.nulls(); let row_len = array.len(); - let valid_row_indices: Vec = if let Some(nulls) = nulls { - nulls.valid_indices().collect() - } else { - (0..row_len).collect() - }; - // Create hashes for each row that combines the hashes over all the column at that row. let mut values_hashes = vec![0u64; row_len]; create_hashes(array.columns(), random_state, &mut values_hashes)?; - for i in valid_row_indices { - let hash = &mut hashes_buffer[i]; - *hash = combine_hashes(*hash, values_hashes[i]); + // Separate paths to avoid allocating Vec when there are no nulls + if let Some(nulls) = nulls { + for i in nulls.valid_indices() { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + } else { + for i in 0..row_len { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } } Ok(()) @@ -274,15 +578,29 @@ fn hash_map_array( let offsets = array.offsets(); // Create hashes for each entry in each row - let mut values_hashes = vec![0u64; array.entries().len()]; - create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + let first_offset = offsets.first().copied().unwrap_or_default() as usize; + let last_offset = offsets.last().copied().unwrap_or_default() as usize; + let entries_len = last_offset - first_offset; + + // Only hash the entries that are actually referenced + let mut values_hashes = vec![0u64; entries_len]; + let entries = array.entries(); + let sliced_columns: Vec = entries + .columns() + .iter() + .map(|col| col.slice(first_offset, entries_len)) + .collect(); + create_hashes(&sliced_columns, random_state, &mut values_hashes)?; // Combine the hashes for entries on each row with each other and previous hash for that row + // Adjust indices by first_offset since values_hashes is sliced starting from first_offset if let Some(nulls) = nulls { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -290,7 +608,9 @@ fn hash_map_array( } else { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -308,24 +628,80 @@ fn hash_list_array( where OffsetSize: OffsetSizeTrait, { - let values = Arc::clone(array.values()); + // In case values is sliced, hash only the bytes used by the offsets of this ListArray + let first_offset = array.value_offsets().first().cloned().unwrap_or_default(); + let last_offset = 
array.value_offsets().last().cloned().unwrap_or_default(); + let value_bytes_len = (last_offset - first_offset).as_usize(); + let mut values_hashes = vec![0u64; value_bytes_len]; + create_hashes( + [array + .values() + .slice(first_offset.as_usize(), value_bytes_len)], + random_state, + &mut values_hashes, + )?; + + if array.null_count() > 0 { + for (i, (start, stop)) in array.value_offsets().iter().tuple_windows().enumerate() + { + if array.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[(*start - first_offset).as_usize() + ..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for ((start, stop), hash) in array + .value_offsets() + .iter() + .tuple_windows() + .zip(hashes_buffer.iter_mut()) + { + for values_hash in &values_hashes + [(*start - first_offset).as_usize()..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_list_view_array( + array: &GenericListViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> +where + OffsetSize: OffsetSizeTrait, +{ + let values = array.values(); let offsets = array.value_offsets(); + let sizes = array.value_sizes(); let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; - create_hashes(&[values], random_state, &mut values_hashes)?; + create_hashes([values], random_state, &mut values_hashes)?; if let Some(nulls) = nulls { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } } } else { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } @@ -333,17 +709,145 @@ where Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_union_array( + array: &UnionArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let DataType::Union(union_fields, _mode) = array.data_type() else { + unreachable!() + }; + + if array.is_dense() { + // Dense union: children only contain values of their type, so they're already compact. + // Use the default hashing approach which is efficient for dense unions. + hash_union_array_default(array, union_fields, random_state, hashes_buffer) + } else { + // Sparse union: each child has the same length as the union array. + // Optimization: only hash the elements that are actually referenced by type_ids, + // instead of hashing all K*N elements (where K = num types, N = array length). + hash_sparse_union_array(array, union_fields, random_state, hashes_buffer) + } +} + +/// Default hashing for union arrays - hashes all elements of each child array fully. 
+/// +/// This approach works for both dense and sparse union arrays: +/// - Dense unions: children are compact (each child only contains values of that type) +/// - Sparse unions: children have the same length as the union array +/// +/// For sparse unions with 3+ types, the optimized take/scatter approach in +/// `hash_sparse_union_array` is more efficient, but for 1-2 types or dense unions, +/// this simpler approach is preferred. +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_union_array_default( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let mut child_hashes: HashMap> = + HashMap::with_capacity(union_fields.len()); + + // Hash each child array fully + for (type_id, _field) in union_fields.iter() { + let child = array.child(type_id); + let mut child_hash_buffer = vec![0; child.len()]; + create_hashes([child], random_state, &mut child_hash_buffer)?; + + child_hashes.insert(type_id, child_hash_buffer); + } + + // Combine hashes for each row using the appropriate child offset + // For dense unions: value_offset points to the actual position in the child + // For sparse unions: value_offset equals the row index + #[expect(clippy::needless_range_loop)] + for i in 0..array.len() { + let type_id = array.type_id(i); + let child_offset = array.value_offset(i); + + let child_hash = child_hashes.get(&type_id).expect("invalid type_id"); + hashes_buffer[i] = combine_hashes(hashes_buffer[i], child_hash[child_offset]); + } + + Ok(()) +} + +/// Hash a sparse union array. +/// Sparse unions have child arrays with the same length as the union array. +/// For 3+ types, we optimize by only hashing the N elements that are actually used +/// (via take/scatter), instead of hashing all K*N elements. +/// +/// For 1-2 types, the overhead of take/scatter outweighs the benefit, so we use +/// the default approach of hashing all children (same as dense unions). +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_sparse_union_array( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + use std::collections::HashMap; + + // For 1-2 types, the take/scatter overhead isn't worth it. + // Fall back to the default approach (same as dense union). 
+ if union_fields.len() <= 2 { + return hash_union_array_default( + array, + union_fields, + random_state, + hashes_buffer, + ); + } + + let type_ids = array.type_ids(); + + // Group indices by type_id + let mut indices_by_type: HashMap> = HashMap::new(); + for (i, &type_id) in type_ids.iter().enumerate() { + indices_by_type.entry(type_id).or_default().push(i as u32); + } + + // For each type, extract only the needed elements, hash them, and scatter back + for (type_id, _field) in union_fields.iter() { + if let Some(indices) = indices_by_type.get(&type_id) { + if indices.is_empty() { + continue; + } + + let child = array.child(type_id); + let indices_array = UInt32Array::from(indices.clone()); + + // Extract only the elements we need using take() + let filtered = take(child.as_ref(), &indices_array, None)?; + + // Hash the filtered array + let mut filtered_hashes = vec![0u64; filtered.len()]; + create_hashes([&filtered], random_state, &mut filtered_hashes)?; + + // Scatter hashes back to correct positions + for (hash, &idx) in filtered_hashes.iter().zip(indices.iter()) { + hashes_buffer[idx as usize] = + combine_hashes(hashes_buffer[idx as usize], *hash); + } + } + } + + Ok(()) +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_fixed_list_array( array: &FixedSizeListArray, random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> { - let values = Arc::clone(array.values()); + let values = array.values(); let value_length = array.value_length() as usize; let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; - create_hashes(&[values], random_state, &mut values_hashes)?; + create_hashes([values], random_state, &mut values_hashes)?; if let Some(nulls) = nulls { for i in 0..array.len() { if nulls.is_valid(i) { @@ -366,83 +870,246 @@ fn hash_fixed_list_array( Ok(()) } -/// Test version of `create_hashes` that produces the same value for -/// all hashes (to test collisions) -/// -/// See comments on `hashes_buffer` for more details +/// Inner hash function for RunArray +#[inline(never)] +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array_inner< + R: RunEndIndexType, + const HAS_NULL_VALUES: bool, + const REHASH: bool, +>( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + // We find the relevant runs that cover potentially sliced arrays, so we can only hash those + // values. Then we find the runs that refer to the original runs and ensure that we apply + // hashes correctly to the sliced, whether sliced at the start, end, or both. + let array_offset = array.offset(); + let array_len = array.len(); + + if array_len == 0 { + return Ok(()); + } + + let run_ends = array.run_ends(); + let run_ends_values = run_ends.values(); + let values = array.values(); + + let start_physical_index = array.get_start_physical_index(); + // get_end_physical_index returns the inclusive last index, but we need the exclusive range end + // for the operations we use below. 
+ let end_physical_index = array.get_end_physical_index() + 1; + + let sliced_values = values.slice( + start_physical_index, + end_physical_index - start_physical_index, + ); + let mut values_hashes = vec![0u64; sliced_values.len()]; + create_hashes( + std::slice::from_ref(&sliced_values), + random_state, + &mut values_hashes, + )?; + + let mut start_in_slice = 0; + for (adjusted_physical_index, &absolute_run_end) in run_ends_values + [start_physical_index..end_physical_index] + .iter() + .enumerate() + { + let absolute_run_end = absolute_run_end.as_usize(); + let end_in_slice = (absolute_run_end - array_offset).min(array_len); + + if HAS_NULL_VALUES && sliced_values.is_null(adjusted_physical_index) { + start_in_slice = end_in_slice; + continue; + } + + let value_hash = values_hashes[adjusted_physical_index]; + let run_slice = &mut hashes_buffer[start_in_slice..end_in_slice]; + + if REHASH { + for hash in run_slice.iter_mut() { + *hash = combine_hashes(value_hash, *hash); + } + } else { + run_slice.fill(value_hash); + } + + start_in_slice = end_in_slice; + } + + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + let has_null_values = array.values().null_count() != 0; + + match (has_null_values, rehash) { + (false, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (false, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + } +} + +/// Internal helper function that hashes a single array and either initializes or combines +/// the hash values in the buffer. +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_single_array( + array: &dyn Array, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + downcast_primitive_array! { + array => hash_array_primitive(array, random_state, hashes_buffer, rehash), + DataType::Null => hash_null(random_state, hashes_buffer, rehash), + DataType::Boolean => hash_array(&as_boolean_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8 => hash_array(&as_string_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8View => hash_generic_byte_view_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash), + DataType::LargeUtf8 => hash_array(&as_largestring_array(array), random_state, hashes_buffer, rehash), + DataType::Binary => hash_array(&as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), + DataType::BinaryView => hash_generic_byte_view_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash), + DataType::LargeBinary => hash_array(&as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), + DataType::FixedSizeBinary(_) => { + let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); + hash_array(&array, random_state, hashes_buffer, rehash) + } + DataType::Dictionary(_, _) => downcast_dictionary_array! 
{ + array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, + _ => unreachable!() + } + DataType::Struct(_) => { + let array = as_struct_array(array)?; + hash_struct_array(array, random_state, hashes_buffer)?; + } + DataType::List(_) => { + let array = as_list_array(array)?; + hash_list_array(array, random_state, hashes_buffer)?; + } + DataType::LargeList(_) => { + let array = as_large_list_array(array)?; + hash_list_array(array, random_state, hashes_buffer)?; + } + DataType::ListView(_) => { + let array = as_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } + DataType::LargeListView(_) => { + let array = as_large_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } + DataType::Map(_, _) => { + let array = as_map_array(array)?; + hash_map_array(array, random_state, hashes_buffer)?; + } + DataType::FixedSizeList(_,_) => { + let array = as_fixed_size_list_array(array)?; + hash_fixed_list_array(array, random_state, hashes_buffer)?; + } + DataType::Union(_, _) => { + let array = as_union_array(array)?; + hash_union_array(array, random_state, hashes_buffer)?; + } + DataType::RunEndEncoded(_, _) => downcast_run_array! { + array => hash_run_array(array, random_state, hashes_buffer, rehash)?, + _ => unreachable!() + } + _ => { + // This is internal because we should have caught this before. + return _internal_err!( + "Unsupported data type in hasher: {}", + array.data_type() + ); + } + } + Ok(()) +} + +/// Test version of `hash_single_array` that forces all hashes to collide to zero. #[cfg(feature = "force_hash_collisions")] -pub fn create_hashes<'a>( - _arrays: &[ArrayRef], +fn hash_single_array( + _array: &dyn Array, _random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { + hashes_buffer: &mut [u64], + _rehash: bool, +) -> Result<()> { for hash in hashes_buffer.iter_mut() { *hash = 0 } - Ok(hashes_buffer) + Ok(()) } -/// Creates hash values for every row, based on the values in the -/// columns. +/// Something that can be returned as a `&dyn Array`. +/// +/// We want `create_hashes` to accept either `&dyn Array` or `ArrayRef`, +/// and this seems the best way to do so. +/// +/// We tried having it accept `AsRef` +/// but that is not implemented for and cannot be implemented for +/// `&dyn Array` so callers that have the latter would not be able +/// to call `create_hashes` directly. This shim trait makes it possible. +pub trait AsDynArray { + fn as_dyn_array(&self) -> &dyn Array; +} + +impl AsDynArray for dyn Array { + fn as_dyn_array(&self) -> &dyn Array { + self + } +} + +impl AsDynArray for &dyn Array { + fn as_dyn_array(&self) -> &dyn Array { + *self + } +} + +impl AsDynArray for ArrayRef { + fn as_dyn_array(&self) -> &dyn Array { + self.as_ref() + } +} + +impl AsDynArray for &ArrayRef { + fn as_dyn_array(&self) -> &dyn Array { + self.as_ref() + } +} + +/// Creates hash values for every row, based on the values in the columns. /// /// The number of rows to hash is determined by `hashes_buffer.len()`. -/// `hashes_buffer` should be pre-sized appropriately -#[cfg(not(feature = "force_hash_collisions"))] -pub fn create_hashes<'a>( - arrays: &[ArrayRef], +/// `hashes_buffer` should be pre-sized appropriately. 
+pub fn create_hashes<'a, I, T>( + arrays: I, random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for (i, col) in arrays.iter().enumerate() { - let array = col.as_ref(); + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> +where + I: IntoIterator, + T: AsDynArray, +{ + for (i, array) in arrays.into_iter().enumerate() { // combine hashes with `combine_hashes` for all columns besides the first let rehash = i >= 1; - downcast_primitive_array! { - array => hash_array_primitive(array, random_state, hashes_buffer, rehash), - DataType::Null => hash_null(random_state, hashes_buffer, rehash), - DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash), - DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash), - DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash), - DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash), - DataType::Binary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), - DataType::BinaryView => hash_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash), - DataType::LargeBinary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), - DataType::FixedSizeBinary(_) => { - let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); - hash_array(array, random_state, hashes_buffer, rehash) - } - DataType::Dictionary(_, _) => downcast_dictionary_array! { - array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, - _ => unreachable!() - } - DataType::Struct(_) => { - let array = as_struct_array(array)?; - hash_struct_array(array, random_state, hashes_buffer)?; - } - DataType::List(_) => { - let array = as_list_array(array)?; - hash_list_array(array, random_state, hashes_buffer)?; - } - DataType::LargeList(_) => { - let array = as_large_list_array(array)?; - hash_list_array(array, random_state, hashes_buffer)?; - } - DataType::Map(_, _) => { - let array = as_map_array(array)?; - hash_map_array(array, random_state, hashes_buffer)?; - } - DataType::FixedSizeList(_,_) => { - let array = as_fixed_size_list_array(array)?; - hash_fixed_list_array(array, random_state, hashes_buffer)?; - } - _ => { - // This is internal because we should have caught this before. 
- return _internal_err!( - "Unsupported data type in hasher: {}", - col.data_type() - ); - } - } + hash_single_array(array.as_dyn_array(), random_state, hashes_buffer, rehash)?; } Ok(hashes_buffer) } @@ -465,8 +1132,8 @@ mod tests { .collect::() .with_precision_and_scale(20, 3) .unwrap(); - let array_ref = Arc::new(array); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let array_ref: ArrayRef = Arc::new(array); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; array_ref.len()]; let hashes = create_hashes(&[array_ref], &random_state, hashes_buff)?; assert_eq!(hashes.len(), 4); @@ -476,19 +1143,25 @@ mod tests { #[test] fn create_hashes_for_empty_fixed_size_lit() -> Result<()> { let empty_array = FixedSizeListBuilder::new(StringBuilder::new(), 1).finish(); - let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; 0]; - let hashes = create_hashes(&[Arc::new(empty_array)], &random_state, hashes_buff)?; + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut [0; 0]; + let hashes = create_hashes( + &[Arc::new(empty_array) as ArrayRef], + &random_state, + hashes_buff, + )?; assert_eq!(hashes, &Vec::::new()); Ok(()) } #[test] fn create_hashes_for_float_arrays() -> Result<()> { - let f32_arr = Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); - let f64_arr = Arc::new(Float64Array::from(vec![0.12, 0.5, 1f64, 444.7])); + let f32_arr: ArrayRef = + Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); + let f64_arr: ArrayRef = + Arc::new(Float64Array::from(vec![0.12, 0.5, 1f64, 444.7])); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; f32_arr.len()]; let hashes = create_hashes(&[f32_arr], &random_state, hashes_buff)?; assert_eq!(hashes.len(), 4,); @@ -514,18 +1187,15 @@ mod tests { Some(b"Longer than 12 bytes string"), ]; - let binary_array = Arc::new(binary.iter().cloned().collect::<$ARRAY>()); - let ref_array = Arc::new(binary.iter().cloned().collect::()); + let binary_array: ArrayRef = + Arc::new(binary.iter().cloned().collect::<$ARRAY>()); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut binary_hashes = vec![0; binary.len()]; create_hashes(&[binary_array], &random_state, &mut binary_hashes) .unwrap(); - let mut ref_hashes = vec![0; binary.len()]; - create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap(); - // Null values result in a zero hash, for (val, hash) in binary.iter().zip(binary_hashes.iter()) { match val { @@ -534,9 +1204,6 @@ mod tests { } } - // same logical values should hash to the same hash value - assert_eq!(binary_hashes, ref_hashes); - // Same values should map to same hash values assert_eq!(binary[0], binary[5]); assert_eq!(binary[4], binary[6]); @@ -548,15 +1215,16 @@ mod tests { } create_hash_binary!(binary_array, BinaryArray); + create_hash_binary!(large_binary_array, LargeBinaryArray); create_hash_binary!(binary_view_array, BinaryViewArray); #[test] fn create_hashes_fixed_size_binary() -> Result<()> { let input_arg = vec![vec![1, 2], vec![5, 6], vec![5, 6]]; - let fixed_size_binary_array = + let fixed_size_binary_array: ArrayRef = Arc::new(FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap()); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; fixed_size_binary_array.len()]; let hashes = 
create_hashes(&[fixed_size_binary_array], &random_state, hashes_buff)?; @@ -580,15 +1248,16 @@ mod tests { Some("Longer than 12 bytes string"), ]; - let string_array = Arc::new(strings.iter().cloned().collect::<$ARRAY>()); - let dict_array = Arc::new( + let string_array: ArrayRef = + Arc::new(strings.iter().cloned().collect::<$ARRAY>()); + let dict_array: ArrayRef = Arc::new( strings .iter() .cloned() .collect::>(), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut string_hashes = vec![0; strings.len()]; create_hashes(&[string_array], &random_state, &mut string_hashes) @@ -623,21 +1292,90 @@ mod tests { create_hash_string!(string_view_array, StringArray); create_hash_string!(dict_string_array, DictionaryArray); + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_run_array() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![10, 20, 30])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut vec![0; array.len()]; + let hashes = create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + hashes_buff, + )?; + + assert_eq!(hashes.len(), 7); + assert_eq!(hashes[0], hashes[1]); + assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[3], hashes[4]); + assert_eq!(hashes[5], hashes[6]); + assert_ne!(hashes[0], hashes[2]); + assert_ne!(hashes[2], hashes[5]); + assert_ne!(hashes[0], hashes[5]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_multi_column_hash_with_run_array() -> Result<()> { + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])); + let values = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let run_array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let mut one_col_hashes = vec![0; int_array.len()]; + create_hashes( + &[Arc::clone(&int_array) as ArrayRef], + &random_state, + &mut one_col_hashes, + )?; + + let mut two_col_hashes = vec![0; int_array.len()]; + create_hashes( + &[ + Arc::clone(&int_array) as ArrayRef, + Arc::clone(&run_array) as ArrayRef, + ], + &random_state, + &mut two_col_hashes, + )?; + + assert_eq!(one_col_hashes.len(), 7); + assert_eq!(two_col_hashes.len(), 7); + assert_ne!(one_col_hashes, two_col_hashes); + + let diff_0_vs_1_one_col = one_col_hashes[0] != one_col_hashes[1]; + let diff_0_vs_1_two_col = two_col_hashes[0] != two_col_hashes[1]; + assert_eq!(diff_0_vs_1_one_col, diff_0_vs_1_two_col); + + let diff_2_vs_3_one_col = one_col_hashes[2] != one_col_hashes[3]; + let diff_2_vs_3_two_col = two_col_hashes[2] != two_col_hashes[3]; + assert_eq!(diff_2_vs_3_one_col, diff_2_vs_3_two_col); + + Ok(()) + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] fn create_hashes_for_dict_arrays() { let strings = [Some("foo"), None, Some("bar"), Some("foo"), None]; - let string_array = Arc::new(strings.iter().cloned().collect::()); - let dict_array = Arc::new( + let string_array: ArrayRef = + Arc::new(strings.iter().cloned().collect::()); + let dict_array: ArrayRef = Arc::new( strings .iter() .cloned() .collect::>(), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = 
RandomState::with_seed(0); let mut string_hashes = vec![0; strings.len()]; create_hashes(&[string_array], &random_state, &mut string_hashes).unwrap(); @@ -682,7 +1420,7 @@ mod tests { ]; let list_array = Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; list_array.len()]; create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[5]); @@ -691,6 +1429,130 @@ mod tests { assert_eq!(hashes[1], hashes[6]); // null vs empty list } + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_list_arrays() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Slice from here + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(3), None, Some(5)]), + None, + // To here + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), + ]; + let list_array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + let list_array = list_array.slice(2, 3); + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; list_array.len()]; + create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[1], hashes[2]); + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create ListView with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: [3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i32, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i32, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let list_view_array = + Arc::new(ListViewArray::new(field, offsets, sizes, values, nulls)) + as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; list_view_array.len()]; + create_hashes(&[list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_large_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + 
Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create LargeListView with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: [3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i64, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i64, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let large_list_view_array = Arc::new(LargeListViewArray::new( + field, offsets, sizes, values, nulls, + )) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; large_list_view_array.len()]; + create_hashes(&[large_list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] @@ -707,7 +1569,7 @@ mod tests { Arc::new(FixedSizeListArray::from_iter_primitive::( data, 3, )) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; list_array.len()]; create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[5]); @@ -757,7 +1619,7 @@ mod tests { let array = Arc::new(struct_array) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; array.len()]; create_hashes(&[array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[1]); @@ -794,7 +1656,7 @@ mod tests { assert!(struct_array.is_valid(1)); let array = Arc::new(struct_array) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; array.len()]; create_hashes(&[array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[1]); @@ -847,7 +1709,7 @@ mod tests { let array = Arc::new(builder.finish()) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; array.len()]; create_hashes(&[array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[1]); // same value @@ -865,15 +1727,16 @@ mod tests { let strings1 = [Some("foo"), None, Some("bar")]; let strings2 = [Some("blarg"), Some("blah"), None]; - let string_array = Arc::new(strings1.iter().cloned().collect::()); - let dict_array = Arc::new( + let string_array: ArrayRef = + Arc::new(strings1.iter().cloned().collect::()); + let dict_array: ArrayRef = Arc::new( strings2 .iter() .cloned() .collect::>(), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut 
one_col_hashes = vec![0; strings1.len()]; create_hashes( @@ -896,4 +1759,345 @@ mod tests { assert_ne!(one_col_hashes, two_col_hashes); } + + #[test] + fn test_create_hashes_from_arrays() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let float_array: ArrayRef = + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut vec![0; int_array.len()]; + let hashes = + create_hashes(&[int_array, float_array], &random_state, hashes_buff).unwrap(); + assert_eq!(hashes.len(), 4,); + } + + #[test] + fn test_create_hashes_from_dyn_arrays() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let float_array: ArrayRef = + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + + // Verify that we can call create_hashes with only &dyn Array + fn test(arr1: &dyn Array, arr2: &dyn Array) { + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut vec![0; arr1.len()]; + let hashes = create_hashes([arr1, arr2], &random_state, hashes_buff).unwrap(); + assert_eq!(hashes.len(), 4,); + } + test(&*int_array, &*float_array); + } + + #[test] + fn test_create_hashes_equivalence() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let random_state = RandomState::with_seed(0); + + let mut hashes1 = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut hashes1, + ) + .unwrap(); + + let mut hashes2 = vec![0; array.len()]; + create_hashes([array], &random_state, &mut hashes2).unwrap(); + + assert_eq!(hashes1, hashes2); + } + + #[test] + fn test_with_hashes() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let random_state = RandomState::with_seed(0); + + // Test that with_hashes produces the same results as create_hashes + let mut expected_hashes = vec![0; array.len()]; + create_hashes([&array], &random_state, &mut expected_hashes).unwrap(); + + let result = with_hashes([&array], &random_state, |hashes| { + assert_eq!(hashes.len(), 4); + // Verify hashes match expected values + assert_eq!(hashes, &expected_hashes[..]); + // Return a copy of the hashes + Ok(hashes.to_vec()) + }) + .unwrap(); + + // Verify callback result is returned correctly + assert_eq!(result, expected_hashes); + } + + #[test] + fn test_with_hashes_multi_column() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let str_array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let random_state = RandomState::with_seed(0); + + // Test multi-column hashing + let mut expected_hashes = vec![0; int_array.len()]; + create_hashes( + [&int_array, &str_array], + &random_state, + &mut expected_hashes, + ) + .unwrap(); + + with_hashes([&int_array, &str_array], &random_state, |hashes| { + assert_eq!(hashes.len(), 3); + assert_eq!(hashes, &expected_hashes[..]); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_with_hashes_empty_arrays() { + let random_state = RandomState::with_seed(0); + + // Test that passing no arrays returns an error + let empty: [&ArrayRef; 0] = []; + let result = with_hashes(empty, &random_state, |_hashes| Ok(())); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("requires at least one array") + ); + } + + #[test] + fn test_with_hashes_reentrancy() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let array2: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); + let random_state = 
RandomState::with_seed(0); + + // Test that reentrant calls return an error instead of panicking + let result = with_hashes([&array], &random_state, |_hashes| { + // Try to call with_hashes again inside the callback + with_hashes([&array2], &random_state, |_inner_hashes| Ok(())) + }); + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("reentrantly") || err_msg.contains("cannot be called"), + "Error message should mention reentrancy: {err_msg}", + ); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sparse_union_arrays() { + // logical array: [int(5), str("foo"), int(10), int(5)] + let int_array = Int32Array::from(vec![Some(5), None, Some(10), Some(5)]); + let str_array = StringArray::from(vec![None, Some("foo"), None, None]); + + let type_ids = vec![0_i8, 1, 0, 0].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(Field::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // Rows 0 and 3 both have type_id=0 (int) with value 5 + assert_eq!(hashes[0], hashes[3]); + // Row 0 (int 5) vs Row 2 (int 10) - different values + assert_ne!(hashes[0], hashes[2]); + // Row 0 (int) vs Row 1 (string) - different types + assert_ne!(hashes[0], hashes[1]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sparse_union_arrays_with_nulls() { + // logical array: [int(5), str("foo"), int(null), str(null)] + let int_array = Int32Array::from(vec![Some(5), None, None, None]); + let str_array = StringArray::from(vec![None, Some("foo"), None, None]); + + let type_ids = vec![0, 1, 0, 1].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(Field::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // row 2 (int null) and row 3 (str null) should have the same hash + // because they are both null values + assert_eq!(hashes[2], hashes[3]); + + // row 0 (int 5) vs row 2 (int null) - different (value vs null) + assert_ne!(hashes[0], hashes[2]); + + // row 1 (str "foo") vs row 3 (str null) - different (value vs null) + assert_ne!(hashes[1], hashes[3]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_dense_union_arrays() { + // creates a dense union array with int and string types + // [67, "norm", 100, "macdonald", 67] + let int_array = Int32Array::from(vec![67, 100, 67]); + let str_array = StringArray::from(vec!["norm", "macdonald"]); + + let type_ids = vec![0, 1, 0, 1, 0].into(); + let offsets = vec![0, 0, 1, 1, 2].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let 
union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, false))), + (1, Arc::new(Field::new("b", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + + let array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // 67 vs "norm" + assert_ne!(hashes[0], hashes[1]); + // 67 vs 100 + assert_ne!(hashes[0], hashes[2]); + // "norm" vs "macdonald" + assert_ne!(hashes[1], hashes[3]); + // 100 vs "macdonald" + assert_ne!(hashes[2], hashes[3]); + // 67 vs 67 + assert_eq!(hashes[0], hashes[4]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_run_array() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![10, 20, 30])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let mut full_hashes = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut full_hashes, + )?; + + let array_ref: ArrayRef = Arc::clone(&array) as ArrayRef; + let sliced_array = array_ref.slice(2, 3); + + let mut sliced_hashes = vec![0; sliced_array.len()]; + create_hashes( + std::slice::from_ref(&sliced_array), + &random_state, + &mut sliced_hashes, + )?; + + assert_eq!(sliced_hashes.len(), 3); + assert_eq!(sliced_hashes[0], sliced_hashes[1]); + assert_eq!(sliced_hashes[1], sliced_hashes[2]); + assert_eq!(&sliced_hashes, &full_hashes[2..5]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn test_run_array_with_nulls() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_ends = Arc::new(Int32Array::from(vec![2, 4, 6])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut hashes, + )?; + + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[0], 0); + assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[2], 0); + assert_eq!(hashes[4], hashes[5]); + assert_ne!(hashes[4], 0); + assert_ne!(hashes[0], hashes[4]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn test_run_array_with_nulls_multicolumn() -> Result<()> { + let primitive_array = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_values = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_ends = Arc::new(Int32Array::from(vec![1, 2, 3])); + let run_array = + Arc::new(RunArray::try_new(&run_ends, run_values.as_ref()).unwrap()); + let second_col = Arc::new(Int32Array::from(vec![100, 200, 300])); + + let random_state = RandomState::with_seed(0); + + let mut primitive_hashes = vec![0; 3]; + create_hashes( + &[ + Arc::clone(&primitive_array) as ArrayRef, + Arc::clone(&second_col) as ArrayRef, + ], + &random_state, + &mut primitive_hashes, + )?; + + let mut run_hashes = vec![0; 3]; + create_hashes( + &[ + Arc::clone(&run_array) as ArrayRef, + Arc::clone(&second_col) as ArrayRef, + ], + &random_state, + &mut run_hashes, + )?; + + assert_eq!(primitive_hashes, run_hashes); + + Ok(()) + } } diff --git a/datafusion/common/src/instant.rs 
b/datafusion/common/src/instant.rs index 42f21c061c0c2..a5dfb28292581 100644 --- a/datafusion/common/src/instant.rs +++ b/datafusion/common/src/instant.rs @@ -22,7 +22,7 @@ /// under `wasm` feature gate. It provides the same API as [`std::time::Instant`]. pub type Instant = web_time::Instant; -#[allow(clippy::disallowed_types)] +#[expect(clippy::disallowed_types)] #[cfg(not(target_family = "wasm"))] /// DataFusion wrapper around [`std::time::Instant`]. This is only a type alias. pub type Instant = std::time::Instant; diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index e6a90db2dc3eb..d517844db48b4 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -97,6 +97,35 @@ impl JoinType { } } + /// Whether each side of the join is preserved for ON-clause filter pushdown. + /// + /// It is only correct to push ON-clause filters below a join for preserved + /// inputs. + /// + /// # "Preserved" input definition + /// + /// A join side is preserved if the join returns all or a subset of the rows + /// from that side, such that each output row directly maps to an input row. + /// If a side is not preserved, the join can produce extra null rows that + /// don't map to any input row. + /// + /// # Return Value + /// + /// A tuple of booleans - (left_preserved, right_preserved). + pub fn on_lr_is_preserved(&self) -> (bool, bool) { + match self { + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (false, false), + JoinType::LeftSemi | JoinType::RightSemi => (true, true), + JoinType::LeftAnti => (false, true), + JoinType::RightAnti => (true, false), + JoinType::LeftMark => (false, true), + JoinType::RightMark => (true, false), + } + } + /// Does the join type support swapping inputs? pub fn supports_swap(&self) -> bool { matches!( @@ -113,6 +142,20 @@ impl JoinType { | JoinType::RightMark ) } + + /// Returns true when an empty build side necessarily produces an empty + /// result for this join type.
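+    ///
+    /// # Example
+    ///
+    /// A minimal sketch of the contract (using the `JoinType` re-export from
+    /// this crate):
+    ///
+    /// ```
+    /// use datafusion_common::JoinType;
+    ///
+    /// // An empty build side means an inner join cannot emit any rows
+    /// assert!(JoinType::Inner.empty_build_side_produces_empty_result());
+    /// // A full join still emits probe-side rows padded with nulls
+    /// assert!(!JoinType::Full.empty_build_side_produces_empty_result());
+    /// ```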
+ pub fn empty_build_side_produces_empty_result(self) -> bool { + matches!( + self, + JoinType::Inner + | JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::RightSemi + ) + } } impl Display for JoinType { diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 76c7b46e32737..996c563f0d8a2 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -23,14 +23,13 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] mod column; mod dfschema; mod functional_dependencies; mod join_type; mod param_value; -#[cfg(feature = "pyarrow")] -mod pyarrow; mod schema_reference; mod table_reference; mod unnest; @@ -51,6 +50,7 @@ pub mod instant; pub mod metadata; pub mod nested_struct; mod null_equality; +pub mod parquet_config; pub mod parsers; pub mod pruning; pub mod rounding; @@ -61,28 +61,30 @@ pub mod test_util; pub mod tree_node; pub mod types; pub mod utils; - /// Reexport arrow crate pub use arrow; pub use column::Column; pub use dfschema::{ - qualified_name, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema, + DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema, qualified_name, }; pub use diagnostic::Diagnostic; +pub use display::human_readable::{ + human_readable_count, human_readable_duration, human_readable_size, units, +}; pub use error::{ - field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, - SharedResult, + DataFusionError, Result, SchemaError, SharedResult, field_not_found, + unqualified_field_not_found, }; pub use file_options::file_type::{ - GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, - DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, + DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, GetExt, }; pub use functional_dependencies::{ + Constraint, Constraints, Dependency, FunctionalDependence, FunctionalDependencies, aggregate_functional_dependencies, get_required_group_by_exprs_indices, - get_target_functional_dependencies, Constraint, Constraints, Dependency, - FunctionalDependence, FunctionalDependencies, + get_required_sort_exprs_indices, get_target_functional_dependencies, }; -use hashbrown::hash_map::DefaultHashBuilder; +use hashbrown::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; pub use nested_struct::cast_column; pub use null_equality::NullEquality; @@ -102,9 +104,9 @@ pub use utils::project_schema; // https://github.com/rust-lang/rust/pull/52234#issuecomment-976702997 #[doc(hidden)] pub use error::{ - _config_datafusion_err, _exec_datafusion_err, _internal_datafusion_err, - _not_impl_datafusion_err, _plan_datafusion_err, _resources_datafusion_err, - _substrait_datafusion_err, + _config_datafusion_err, _exec_datafusion_err, _ffi_datafusion_err, + _internal_datafusion_err, _not_impl_datafusion_err, _plan_datafusion_err, + _resources_datafusion_err, _substrait_datafusion_err, }; // The HashMap and HashSet implementations that should be used as the uniform defaults @@ -136,10 +138,10 @@ macro_rules! downcast_value { // Not public API. 
#[doc(hidden)] pub mod __private { - use crate::error::_internal_datafusion_err; use crate::Result; + use crate::error::_internal_datafusion_err; use arrow::array::Array; - use std::any::{type_name, Any}; + use std::any::{Any, type_name}; #[doc(hidden)] pub trait DowncastArrayHelper { @@ -190,7 +192,7 @@ mod tests { assert_starts_with( error.to_string(), - "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray" + "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray", ); } diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs index 3a10cc2b42f9f..d6d8fb7b0ed0c 100644 --- a/datafusion/common/src/metadata.rs +++ b/datafusion/common/src/metadata.rs @@ -17,10 +17,10 @@ use std::{collections::BTreeMap, sync::Arc}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use hashbrown::HashMap; -use crate::{error::_plan_err, DataFusionError, ScalarValue}; +use crate::{DataFusionError, ScalarValue, error::_plan_err}; /// A [`ScalarValue`] with optional [`FieldMetadata`] #[derive(Debug, Clone)] @@ -171,6 +171,10 @@ pub fn format_type_and_metadata( /// // Add any metadata from `FieldMetadata` to `Field` /// let updated_field = metadata.add_to_field(field); /// ``` +/// +/// For more background, please also see the [Implementing User Defined Types and Custom Metadata in DataFusion blog] +/// +/// [Implementing User Defined Types and Custom Metadata in DataFusion blog]: https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct FieldMetadata { /// The inner metadata of a literal expression, which is a map of string @@ -320,6 +324,16 @@ impl FieldMetadata { field.with_metadata(self.to_hashmap()) } + + /// Updates the metadata on the FieldRef with this metadata, if it is not empty. + pub fn add_to_field_ref(&self, mut field_ref: FieldRef) -> FieldRef { + if self.inner.is_empty() { + return field_ref; + } + + Arc::make_mut(&mut field_ref).set_metadata(self.to_hashmap()); + field_ref + } } impl From<&Field> for FieldMetadata { diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index d43816f75b0ed..cdd6215d08e2f 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::error::{Result, _plan_err}; +use crate::error::{_plan_err, Result}; use arrow::{ - array::{new_null_array, Array, ArrayRef, StructArray}, - compute::{cast_with_options, CastOptions}, - datatypes::{DataType::Struct, Field, FieldRef}, + array::{ + Array, ArrayRef, DictionaryArray, GenericListArray, GenericListViewArray, + StructArray, downcast_integer, new_null_array, + }, + compute::{CastOptions, can_cast_types, cast_with_options}, + datatypes::{DataType, DataType::Struct, Field, FieldRef}, }; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; /// Cast a struct column to match target struct fields, handling nested structs recursively. 
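+///
+/// For intuition, a small sketch of the matching rules described below
+/// (values are illustrative): casting `{a: 1, extra: "x"}` to
+/// `Struct {a: Int64, b: Utf8}` casts `a` to `Int64`, fills `b` with nulls,
+/// and ignores `extra`.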
/// @@ -31,6 +34,7 @@ use std::sync::Arc; /// /// ## Field Matching Strategy /// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) +/// - **No Positional Mapping**: Structs with no overlapping field names are rejected /// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type /// - **Missing Fields**: Target fields not present in the source are filled with null values /// - **Extra Fields**: Source fields not present in the target are ignored @@ -54,25 +58,42 @@ fn cast_struct_column( target_fields: &[Arc<Field>], cast_options: &CastOptions, ) -> Result<ArrayRef> { - if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() { - validate_struct_compatibility(source_struct.fields(), target_fields)?; + if source_col.data_type() == &DataType::Null + || (!source_col.is_empty() && source_col.null_count() == source_col.len()) + { + return Ok(new_null_array( + &Struct(target_fields.to_vec().into()), + source_col.len(), + )); + } + if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() { + let source_fields = source_struct.fields(); + validate_struct_compatibility(source_fields, target_fields)?; let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len()); let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len()); let num_rows = source_col.len(); - for target_child_field in target_fields { + // Iterate target fields and pick source child by name when present. + for target_child_field in target_fields.iter() { fields.push(Arc::clone(target_child_field)); - match source_struct.column_by_name(target_child_field.name()) { + + let source_child_opt = + source_struct.column_by_name(target_child_field.name()); + + match source_child_opt { Some(source_child_col) => { - let adapted_child = - cast_column(source_child_col, target_child_field, cast_options) - .map_err(|e| { - e.context(format!( - "While casting struct field '{}'", - target_child_field.name() - )) - })?; + let adapted_child = cast_column( + source_child_col, + target_child_field.data_type(), + cast_options, + ) + .map_err(|e| { + e.context(format!( + "While casting struct field '{}'", + target_child_field.name() + )) + })?; arrays.push(adapted_child); } None => { @@ -112,18 +133,17 @@ fn cast_struct_column( /// ``` /// use arrow::array::{ArrayRef, Int64Array}; /// use arrow::compute::CastOptions; -/// use arrow::datatypes::{DataType, Field}; +/// use arrow::datatypes::DataType; /// use datafusion_common::nested_struct::cast_column; /// use std::sync::Arc; /// /// let source: ArrayRef = Arc::new(Int64Array::from(vec![1, i64::MAX])); -/// let target = Field::new("ints", DataType::Int32, true); /// // Permit lossy conversions by producing NULL on overflow instead of erroring /// let options = CastOptions { /// safe: true, /// ..Default::default() /// }; -/// let result = cast_column(&source, &target, &options).unwrap(); +/// let result = cast_column(&source, &DataType::Int32, &options).unwrap(); /// assert!(result.is_null(1)); /// ``` /// @@ -136,7 +156,7 @@ fn cast_struct_column( /// /// # Arguments /// * `source_col` - The source array to cast -/// * `target_field` - The target field definition (including type and metadata) +/// * `target_type` - The target data type to cast to /// * `cast_options` - Options that govern strictness and formatting of the cast /// /// # Returns @@ -150,18 +170,139 @@ fn cast_struct_column( /// - Invalid data type combinations are encountered pub fn cast_column( source_col: &ArrayRef, - target_field: &Field, +
target_type: &DataType, cast_options: &CastOptions, ) -> Result<ArrayRef> { - match target_field.data_type() { - Struct(target_fields) => { + match (source_col.data_type(), target_type) { + (_, Struct(target_fields)) => { cast_struct_column(source_col, target_fields, cast_options) } - _ => Ok(cast_with_options( + (DataType::List(_), DataType::List(target_inner)) => { + cast_list_column::<i32>(source_col, target_inner, cast_options) + } + (DataType::LargeList(_), DataType::LargeList(target_inner)) => { + cast_list_column::<i64>(source_col, target_inner, cast_options) + } + (DataType::ListView(_), DataType::ListView(target_inner)) => { + cast_list_view_column::<i32>(source_col, target_inner, cast_options) + } + (DataType::LargeListView(_), DataType::LargeListView(target_inner)) => { + cast_list_view_column::<i64>(source_col, target_inner, cast_options) + } + ( + DataType::Dictionary(source_key_type, _), + DataType::Dictionary(target_key_type, target_value_type), + ) => cast_dictionary_column( source_col, - target_field.data_type(), + source_key_type, + target_key_type, + target_value_type, cast_options, - )?), + ), + _ => Ok(cast_with_options(source_col, target_type, cast_options)?), + } +} + +fn cast_list_column<O: arrow::array::OffsetSizeTrait>( + source_col: &ArrayRef, + target_inner_field: &FieldRef, + cast_options: &CastOptions, +) -> Result<ArrayRef> { + let source_list = source_col + .as_any() + .downcast_ref::<GenericListArray<O>>() + .ok_or_else(|| { + crate::error::DataFusionError::Plan(format!( + "Expected list array but got {}", + source_col.data_type() + )) + })?; + + let cast_values = cast_column( + source_list.values(), + target_inner_field.data_type(), + cast_options, + )?; + + let result = GenericListArray::<O>::new( + Arc::clone(target_inner_field), + source_list.offsets().clone(), + cast_values, + source_list.nulls().cloned(), + ); + Ok(Arc::new(result)) +} + +fn cast_list_view_column<O: arrow::array::OffsetSizeTrait>( + source_col: &ArrayRef, + target_inner_field: &FieldRef, + cast_options: &CastOptions, +) -> Result<ArrayRef> { + let source_list = source_col + .as_any() + .downcast_ref::<GenericListViewArray<O>>() + .ok_or_else(|| { + crate::error::DataFusionError::Plan(format!( + "Expected list view array but got {}", + source_col.data_type() + )) + })?; + + let cast_values = cast_column( + source_list.values(), + target_inner_field.data_type(), + cast_options, + )?; + + let result = GenericListViewArray::<O>::try_new( + Arc::clone(target_inner_field), + source_list.offsets().clone(), + source_list.sizes().clone(), + cast_values, + source_list.nulls().cloned(), + )?; + Ok(Arc::new(result)) +} + +fn cast_dictionary_column( + source_col: &ArrayRef, + source_key_type: &DataType, + target_key_type: &DataType, + target_value_type: &DataType, + cast_options: &CastOptions, + ) -> Result<ArrayRef> { + // Dispatch on source key type to access keys/values, then recursively + // cast values. Rebuild with the source key type first. + macro_rules! cast_dict_values { + ($t:ty) => {{ + let source_dict = source_col + .as_any() + .downcast_ref::<DictionaryArray<$t>>() + .expect("downcast must succeed"); + let cast_values = + cast_column(source_dict.values(), target_value_type, cast_options)?; + Ok(Arc::new(DictionaryArray::<$t>::new( + source_dict.keys().clone(), + cast_values, + )) as ArrayRef) + }}; + } + + let result: Result<ArrayRef> = downcast_integer! { + source_key_type => (cast_dict_values), + k => _plan_err!("Unsupported dictionary key type: {k}") + }; + let result = result?; + + // If key types differ, delegate key casting to Arrow.
+ if source_key_type != target_key_type { + let target_dict_type = DataType::Dictionary( + Box::new(target_key_type.clone()), + Box::new(target_value_type.clone()), + ); + Ok(cast_with_options(&result, &target_dict_type, cast_options)?) + } else { + Ok(result) } } @@ -200,10 +341,20 @@ pub fn cast_column( /// // Target: {a: binary} /// // Result: Err(...) - string cannot cast to binary /// ``` +/// pub fn validate_struct_compatibility( source_fields: &[FieldRef], target_fields: &[FieldRef], ) -> Result<()> { + let has_overlap = has_one_of_more_common_fields(source_fields, target_fields); + if !has_overlap { + return _plan_err!( + "Cannot cast struct with {} fields to {} fields because there is no field name overlap", + source_fields.len(), + target_fields.len() + ); + } + // Check compatibility for each target field for target_field in target_fields { // Look for matching field in source by name @@ -211,55 +362,156 @@ pub fn validate_struct_compatibility( .iter() .find(|f| f.name() == target_field.name()) { - // Ensure nullability is compatible. It is invalid to cast a nullable - // source field to a non-nullable target field as this may discard - // null values. - if source_field.is_nullable() && !target_field.is_nullable() { + validate_field_compatibility(source_field, target_field)?; + } else { + // Target field is missing from source + // If it's non-nullable, we cannot fill it with NULL + if !target_field.is_nullable() { return _plan_err!( - "Cannot cast nullable struct field '{}' to non-nullable field", + "Cannot cast struct: target field '{}' is non-nullable but missing from source. \ + Cannot fill with NULL.", target_field.name() ); } - // Check if the matching field types are compatible - match (source_field.data_type(), target_field.data_type()) { - // Recursively validate nested structs - (Struct(source_nested), Struct(target_nested)) => { - validate_struct_compatibility(source_nested, target_nested)?; - } - // For non-struct types, use the existing castability check - _ => { - if !arrow::compute::can_cast_types( - source_field.data_type(), - target_field.data_type(), - ) { - return _plan_err!( - "Cannot cast struct field '{}' from type {} to type {}", - target_field.name(), - source_field.data_type(), - target_field.data_type() - ); - } - } - } } - // Missing fields in source are OK - they'll be filled with nulls } // Extra fields in source are OK - they'll be ignored Ok(()) } +fn validate_field_compatibility( + source_field: &Field, + target_field: &Field, +) -> Result<()> { + if source_field.data_type() == &DataType::Null { + // Validate that target allows nulls before returning early. + // It is invalid to cast a NULL source field to a non-nullable target field. + if !target_field.is_nullable() { + return _plan_err!( + "Cannot cast NULL struct field '{}' to non-nullable field '{}'", + source_field.name(), + target_field.name() + ); + } + return Ok(()); + } + + // Ensure nullability is compatible. It is invalid to cast a nullable + // source field to a non-nullable target field as this may discard + // null values. + if source_field.is_nullable() && !target_field.is_nullable() { + return _plan_err!( + "Cannot cast nullable struct field '{}' to non-nullable field", + target_field.name() + ); + } + + validate_data_type_compatibility( + target_field.name(), + source_field.data_type(), + target_field.data_type(), + ) +} + +/// Validates that `source_type` can be cast to `target_type`, recursively +/// handling container types that wrap structs.
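+///
+/// # Example
+///
+/// A minimal sketch; `"c"` is just an illustrative field name used in
+/// error messages:
+///
+/// ```
+/// use arrow::datatypes::DataType;
+/// use datafusion_common::nested_struct::validate_data_type_compatibility;
+///
+/// // Int64 -> Int32 is an ordinary Arrow cast, so validation succeeds
+/// assert!(
+///     validate_data_type_compatibility("c", &DataType::Int64, &DataType::Int32).is_ok()
+/// );
+/// // Binary -> Int32 is not castable, so validation fails
+/// assert!(
+///     validate_data_type_compatibility("c", &DataType::Binary, &DataType::Int32).is_err()
+/// );
+/// ```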
+pub fn validate_data_type_compatibility( + field_name: &str, + source_type: &DataType, + target_type: &DataType, +) -> Result<()> { + match (source_type, target_type) { + (Struct(source_nested), Struct(target_nested)) => { + validate_struct_compatibility(source_nested, target_nested)?; + } + (DataType::List(s), DataType::List(t)) + | (DataType::LargeList(s), DataType::LargeList(t)) + | (DataType::ListView(s), DataType::ListView(t)) + | (DataType::LargeListView(s), DataType::LargeListView(t)) => { + validate_field_compatibility(s, t)?; + } + (DataType::Dictionary(s_key, s_val), DataType::Dictionary(t_key, t_val)) => { + if !can_cast_types(s_key, t_key) { + return _plan_err!( + "Cannot cast dictionary key type {} to {} for field '{}'", + s_key, + t_key, + field_name + ); + } + validate_data_type_compatibility(field_name, s_val, t_val)?; + } + _ => { + if !can_cast_types(source_type, target_type) { + return _plan_err!( + "Cannot cast struct field '{}' from type {} to type {}", + field_name, + source_type, + target_type + ); + } + } + } + Ok(()) +} + +/// Returns true if casting from `source_type` to `target_type` requires +/// name-based nested struct casting logic, rather than Arrow's standard cast. +/// +/// This is the case when both types are struct types, or both are the same +/// container type (List, LargeList, ListView, LargeListView, Dictionary) wrapping +/// types that recursively contain structs. +/// +/// Use this predicate at both planning time (to decide whether to apply struct +/// compatibility validation) and execution time (to decide whether to route +/// through [`cast_column`] instead of Arrow's generic cast). +pub fn requires_nested_struct_cast( + source_type: &DataType, + target_type: &DataType, +) -> bool { + match (source_type, target_type) { + (Struct(_), Struct(_)) => true, + (DataType::List(s), DataType::List(t)) + | (DataType::LargeList(s), DataType::LargeList(t)) + | (DataType::ListView(s), DataType::ListView(t)) + | (DataType::LargeListView(s), DataType::LargeListView(t)) => { + requires_nested_struct_cast(s.data_type(), t.data_type()) + } + (DataType::Dictionary(_, s_val), DataType::Dictionary(_, t_val)) => { + requires_nested_struct_cast(s_val, t_val) + } + _ => false, + } +} + +/// Check if two field lists have at least one common field by name. +/// +/// This is useful for validating struct compatibility when casting between structs, +/// ensuring that source and target fields have overlapping names. 
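+///
+/// # Example
+///
+/// A minimal sketch:
+///
+/// ```
+/// use std::sync::Arc;
+/// use arrow::datatypes::{DataType, Field};
+/// use datafusion_common::nested_struct::has_one_of_more_common_fields;
+///
+/// let source = vec![Arc::new(Field::new("a", DataType::Int32, true))];
+/// let target = vec![
+///     Arc::new(Field::new("a", DataType::Int64, true)),
+///     Arc::new(Field::new("b", DataType::Utf8, true)),
+/// ];
+/// // The shared name `a` is enough for by-name struct casting
+/// assert!(has_one_of_more_common_fields(&source, &target));
+/// ```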
+pub fn has_one_of_more_common_fields( + source_fields: &[FieldRef], + target_fields: &[FieldRef], +) -> bool { + let source_names: HashSet<&str> = source_fields + .iter() + .map(|field| field.name().as_str()) + .collect(); + target_fields + .iter() + .any(|field| source_names.contains(field.name().as_str())) +} + #[cfg(test)] mod tests { - use super::*; - use crate::format::DEFAULT_CAST_OPTIONS; + use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS}; use arrow::{ array::{ - BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, - MapBuilder, StringArray, StringBuilder, + BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, ListViewArray, + MapArray, MapBuilder, NullArray, StringArray, StringBuilder, }, - buffer::NullBuffer, + buffer::{NullBuffer, ScalarBuffer}, datatypes::{DataType, Field, FieldRef, Int32Type}, }; /// Macro to extract and downcast a column from a StructArray @@ -302,7 +554,9 @@ mod tests { fn test_cast_simple_column() { let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; let target_field = field("ints", DataType::Int64); - let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let result = + cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 3); assert_eq!(result.value(0), 1); @@ -320,14 +574,15 @@ mod tests { safe: false, ..DEFAULT_CAST_OPTIONS }; - assert!(cast_column(&source, &target_field, &safe_opts).is_err()); + assert!(cast_column(&source, target_field.data_type(), &safe_opts).is_err()); let unsafe_opts = CastOptions { // safe: true - return Null for failure safe: true, ..DEFAULT_CAST_OPTIONS }; - let result = cast_column(&source, &target_field, &unsafe_opts).unwrap(); + let result = + cast_column(&source, target_field.data_type(), &unsafe_opts).unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.value(0), 1); assert!(result.is_null(1)); @@ -348,7 +603,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.fields().len(), 2); let a_result = get_column_as!(&struct_array, "a", Int32Array); @@ -366,7 +622,8 @@ mod tests { let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast column of type")); @@ -386,7 +643,8 @@ mod tests { let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast struct field 'a'")); @@ -428,11 +686,14 @@ mod tests { #[test] fn test_validate_struct_compatibility_missing_field_in_source() { - // Source struct: {field2: String} (missing field1) - let source_fields = vec![arc_field("field2", DataType::Utf8)]; + // Source struct: {field1: Int32} (missing 
field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; - // Target struct: {field1: Int32} - let target_fields = vec![arc_field("field1", DataType::Int32)]; + // Target struct: {field1: Int32, field2: Utf8} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; // Should be OK - missing fields will be filled with nulls let result = validate_struct_compatibility(&source_fields, &target_fields); @@ -455,6 +716,20 @@ mod tests { assert!(result.is_ok()); } + #[test] + fn test_validate_struct_compatibility_no_overlap_mismatch_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Int32), + ]; + let target_fields = vec![arc_field("alpha", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + #[test] fn test_cast_struct_parent_nulls_retained() { let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; @@ -466,7 +741,8 @@ mod tests { let target_field = struct_field("s", vec![field("a", DataType::Int64)]); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.null_count(), 1); assert!(struct_array.is_valid(0)); @@ -525,6 +801,117 @@ mod tests { assert!(error_msg.contains("non-nullable")); } + #[test] + fn test_validate_struct_compatibility_by_name() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field2: String, field1: Int64} + let target_fields = vec![ + arc_field("field2", DataType::Utf8), + arc_field("field1", DataType::Int64), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_with_type_mismatch() { + // Source struct: {field1: Binary} + let source_fields = vec![arc_field("field1", DataType::Binary)]; + + // Target struct: {field1: Int32} (incompatible type) + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct field 'field1' from type Binary to type Int32" + ); + } + + #[test] + fn test_validate_struct_compatibility_no_overlap_equal_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Utf8), + ]; + + let target_fields = vec![ + arc_field("alpha", DataType::Int32), + arc_field("beta", DataType::Utf8), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_validate_struct_compatibility_mixed_name_overlap() { + // Source struct: {a: Int32, b: String, extra: Boolean} + let source_fields = vec![ + arc_field("a", DataType::Int32), + arc_field("b", DataType::Utf8), + arc_field("extra", DataType::Boolean), + ]; + + // Target struct: {b: String, a: 
Int64, c: Float32} + // Name overlap with a and b, missing c (nullable) + let target_fields = vec![ + arc_field("b", DataType::Utf8), + arc_field("a", DataType::Int64), + arc_field("c", DataType::Float32), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_missing_required_field() { + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32, field2: Int32 non-nullable} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + Arc::new(non_null_field("field2", DataType::Int32)), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct: target field 'field2' is non-nullable but missing from source. Cannot fill with NULL." + ); + } + + #[test] + fn test_validate_struct_compatibility_partial_name_overlap_with_count_mismatch() { + // Source struct: {a: Int32} (only one field) + let source_fields = vec![arc_field("a", DataType::Int32)]; + + // Target struct: {a: Int32, b: String} (two fields, but 'a' overlaps) + let target_fields = vec![ + arc_field("a", DataType::Int32), + arc_field("b", DataType::Utf8), + ]; + + // This should succeed - partial overlap means by-name mapping + // and missing field 'b' is nullable + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + #[test] fn test_cast_nested_struct_with_extra_and_missing_fields() { // Source inner struct has fields a, b, extra @@ -565,7 +952,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let outer = result.as_any().downcast_ref::().unwrap(); let inner = get_column_as!(&outer, "inner", StructArray); assert_eq!(inner.fields().len(), 3); @@ -585,6 +973,34 @@ mod tests { assert!(missing.is_null(1)); } + #[test] + fn test_cast_null_struct_field_to_nested_struct() { + let null_inner = Arc::new(NullArray::new(2)) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("inner", DataType::Null), + Arc::clone(&null_inner), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "outer", + vec![struct_field("inner", vec![field("a", DataType::Int32)])], + ); + + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.len(), 2); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + + let inner_a = get_column_as!(inner, "a", Int32Array); + assert!(inner_a.is_null(0)); + assert!(inner_a.is_null(1)); + } + #[test] fn test_cast_struct_with_array_and_map_fields() { // Array field with second row null @@ -654,7 +1070,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let arr = get_column_as!(&struct_array, "arr", ListArray); @@ -693,7 +1110,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, 
&DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let b_col = get_column_as!(&struct_array, "b", Int64Array); @@ -704,4 +1122,218 @@ mod tests { assert_eq!(a_col.value(0), 1); assert_eq!(a_col.value(1), 2); } + + #[test] + fn test_cast_struct_no_overlap_rejected() { + let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef; + let second = + Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("left", DataType::Int32), first), + (arc_field("right", DataType::Utf8), second), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int64), field("b", DataType::Utf8)], + ); + + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_cast_struct_missing_non_nullable_field_fails() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' (nullable) and 'b' (non-nullable) + let target_field = struct_field( + "s", + vec![ + field("a", DataType::Int32), + non_null_field("b", DataType::Int32), + ], + ); + + // Should fail because 'b' is non-nullable but missing from source + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string() + .contains("target field 'b' is non-nullable but missing from source"), + "Unexpected error: {err}" + ); + } + + #[test] + fn test_cast_struct_missing_nullable_field_succeeds() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' and 'b' (both nullable) + let target_field = struct_field( + "s", + vec![field("a", DataType::Int32), field("b", DataType::Int32)], + ); + + // Should succeed - 'b' is nullable so can be filled with NULL + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); + + let b_col = get_column_as!(&struct_array, "b", Int32Array); + assert!(b_col.is_null(0)); + assert!(b_col.is_null(1)); + } + + #[test] + fn test_validate_dictionary_value_evolution() { + let source_inner = struct_type(vec![field("a", DataType::Int32)]); + let target_inner = struct_type(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + ]); + let source = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(source_inner)); + let target = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(target_inner)); + assert!(validate_data_type_compatibility("col", &source, &target).is_ok()); + } + + #[test] + fn test_cast_dictionary_struct_value() { + // Build a Dictionary and 
cast to + // Dictionary (field added, type widened). + let struct_arr = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef, + )]); + // keys: [0, null, 1] mapping into the 2-element struct values array. + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let source_dict = DictionaryArray::::new(keys, Arc::new(struct_arr)); + let source_col: ArrayRef = Arc::new(source_dict); + + let target_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(struct_type(vec![ + field("a", DataType::Int64), + field("b", DataType::Utf8), + ])), + ); + + let result = + cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap(); + let result_dict = result + .as_any() + .downcast_ref::>() + .unwrap(); + + assert!(result_dict.is_valid(0)); + assert!(result_dict.is_null(1)); + assert!(result_dict.is_valid(2)); + + let struct_values = result_dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let a_col = get_column_as!(&struct_values, "a", Int64Array); + assert_eq!(a_col.values(), &[10, 20]); + let b_col = get_column_as!(&struct_values, "b", StringArray); + assert!(b_col.iter().all(|v| v.is_none())); + } + + #[test] + fn test_cast_list_view_struct() { + // Build a ListView and cast to + // ListView. + let struct_arr = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]); + + let source_field = + arc_field("item", struct_type(vec![field("a", DataType::Int32)])); + let target_field = arc_field( + "item", + struct_type(vec![ + field("a", DataType::Int64), + field("b", DataType::Utf8), + ]), + ); + + // Two list-view entries: [0..2] and [2..3] + let list_view = ListViewArray::new( + source_field, + ScalarBuffer::from(vec![0i32, 2]), + ScalarBuffer::from(vec![2i32, 1]), + Arc::new(struct_arr), + None, + ); + let source_col: ArrayRef = Arc::new(list_view); + + let target_type = DataType::ListView(target_field); + + let result = + cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap(); + let result_lv = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result_lv.len(), 2); + + let struct_values = result_lv + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let a_col = get_column_as!(&struct_values, "a", Int64Array); + assert_eq!(a_col.values(), &[1, 2, 3]); + let b_col = get_column_as!(&struct_values, "b", StringArray); + assert!(b_col.iter().all(|v| v.is_none())); + } + + #[test] + fn test_requires_nested_struct_cast() { + let s1 = struct_type(vec![field("a", DataType::Int32)]); + let s2 = struct_type(vec![field("a", DataType::Int64)]); + + assert!(requires_nested_struct_cast(&s1, &s2)); + assert!(requires_nested_struct_cast( + &DataType::List(arc_field("item", s1.clone())), + &DataType::List(arc_field("item", s2.clone())), + )); + assert!(requires_nested_struct_cast( + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s1.clone())), + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s2.clone())), + )); + assert!(requires_nested_struct_cast( + &DataType::ListView(arc_field("item", s1)), + &DataType::ListView(arc_field("item", s2)), + )); + + // Non-struct types should return false. 
+ assert!(!requires_nested_struct_cast( + &DataType::Int32, + &DataType::Int64 + )); + assert!(!requires_nested_struct_cast( + &DataType::List(arc_field("item", DataType::Int32)), + &DataType::List(arc_field("item", DataType::Int64)), + )); + } } diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index ebf68e4dd210d..0fac6b529eb0f 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -16,7 +16,7 @@ // under the License. use crate::error::{_plan_datafusion_err, _plan_err}; -use crate::metadata::{check_metadata_with_storage_equal, ScalarAndMetadata}; +use crate::metadata::{ScalarAndMetadata, check_metadata_with_storage_equal}; use crate::{Result, ScalarValue}; use arrow::datatypes::{DataType, Field, FieldRef}; use std::collections::HashMap; diff --git a/datafusion/common/src/parquet_config.rs b/datafusion/common/src/parquet_config.rs new file mode 100644 index 0000000000000..9d6d7a88566a7 --- /dev/null +++ b/datafusion/common/src/parquet_config.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::{self, Display}; +use std::str::FromStr; + +use crate::config::{ConfigField, Visit}; +use crate::error::{DataFusionError, Result}; + +/// Parquet writer version options for controlling the Parquet file format version +/// +/// This enum validates parquet writer version values at configuration time, +/// ensuring only valid versions ("1.0" or "2.0") can be set via `SET` commands +/// or proto deserialization. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DFParquetWriterVersion { + /// Parquet format version 1.0 + #[default] + V1_0, + /// Parquet format version 2.0 + V2_0, +} + +/// Implement parsing strings to `DFParquetWriterVersion` +impl FromStr for DFParquetWriterVersion { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result<Self> { + match s.to_lowercase().as_str() { + "1.0" => Ok(DFParquetWriterVersion::V1_0), + "2.0" => Ok(DFParquetWriterVersion::V2_0), + other => Err(DataFusionError::Configuration(format!( + "Invalid parquet writer version: {other}.
Expected one of: 1.0, 2.0" + ))), + } + } +} + +impl Display for DFParquetWriterVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + DFParquetWriterVersion::V1_0 => "1.0", + DFParquetWriterVersion::V2_0 => "2.0", + }; + write!(f, "{s}") + } +} + +impl ConfigField for DFParquetWriterVersion { + fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = DFParquetWriterVersion::from_str(value)?; + Ok(()) + } +} + +/// Convert `DFParquetWriterVersion` to parquet crate's `WriterVersion` +/// +/// This conversion is infallible since `DFParquetWriterVersion` only contains +/// valid values that have been validated at configuration time. +#[cfg(feature = "parquet")] +impl From<DFParquetWriterVersion> for parquet::file::properties::WriterVersion { + fn from(value: DFParquetWriterVersion) -> Self { + match value { + DFParquetWriterVersion::V1_0 => { + parquet::file::properties::WriterVersion::PARQUET_1_0 + } + DFParquetWriterVersion::V2_0 => { + parquet::file::properties::WriterVersion::PARQUET_2_0 + } + } + } +} + +/// Convert parquet crate's `WriterVersion` to `DFParquetWriterVersion` +/// +/// This is used when converting from existing parquet writer properties, +/// such as when reading from proto or test code. +#[cfg(feature = "parquet")] +impl From<parquet::file::properties::WriterVersion> for DFParquetWriterVersion { + fn from(version: parquet::file::properties::WriterVersion) -> Self { + match version { + parquet::file::properties::WriterVersion::PARQUET_1_0 => { + DFParquetWriterVersion::V1_0 + } + parquet::file::properties::WriterVersion::PARQUET_2_0 => { + DFParquetWriterVersion::V2_0 + } + } + } +} diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index cd3d607dacd88..6b930d110f47b 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -73,3 +73,59 @@ impl CompressionTypeVariant { !matches!(self, &Self::UNCOMPRESSED) } } + +/// CSV quote style +/// +/// Controls when fields are quoted when writing CSV files. +/// Corresponds to [`arrow::csv::QuoteStyle`].
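+///
+/// # Example
+///
+/// A minimal sketch; parsing via `FromStr` is case-insensitive:
+///
+/// ```
+/// use std::str::FromStr;
+/// use datafusion_common::parsers::CsvQuoteStyle;
+///
+/// assert_eq!(
+///     CsvQuoteStyle::from_str("NECESSARY").unwrap(),
+///     CsvQuoteStyle::Necessary
+/// );
+/// assert!(CsvQuoteStyle::from_str("sometimes").is_err());
+/// ```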
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub enum CsvQuoteStyle { + /// Quote all fields + Always, + /// Only quote fields when necessary (default) + #[default] + Necessary, + /// Quote all non-numeric fields + NonNumeric, + /// Never quote fields + Never, +} + +impl FromStr for CsvQuoteStyle { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result<Self> { + match s.to_lowercase().as_str() { + "always" => Ok(Self::Always), + "necessary" => Ok(Self::Necessary), + "non_numeric" | "nonnumeric" => Ok(Self::NonNumeric), + "never" => Ok(Self::Never), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported CSV quote style {s}" + ))), + } + } +} + +impl From<CsvQuoteStyle> for arrow::csv::QuoteStyle { + fn from(style: CsvQuoteStyle) -> Self { + match style { + CsvQuoteStyle::Always => Self::Always, + CsvQuoteStyle::NonNumeric => Self::NonNumeric, + CsvQuoteStyle::Never => Self::Never, + CsvQuoteStyle::Necessary => Self::Necessary, + } + } +} + +impl Display for CsvQuoteStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let str = match self { + Self::Always => "Always", + Self::Necessary => "Necessary", + Self::NonNumeric => "NonNumeric", + Self::Never => "Never", + }; + write!(f, "{str}") + } +} diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs index 48750e3c995c4..ebae23f0723a1 100644 --- a/datafusion/common/src/pruning.rs +++ b/datafusion/common/src/pruning.rs @@ -95,15 +95,17 @@ pub trait PruningStatistics { /// [`UInt64Array`]: arrow::array::UInt64Array fn null_counts(&self, column: &Column) -> Option<ArrayRef>; - /// Return the number of rows for the named column in each container - /// as an [`UInt64Array`]. + /// Return the number of rows in each container as an [`UInt64Array`]. + /// + /// Row counts are container-level (not column-specific) — the value + /// is the same regardless of which column is being considered. /// /// See [`Self::min_values`] for when to return `None` and null values. /// /// Note: the returned array must contain [`Self::num_containers`] rows /// /// [`UInt64Array`]: arrow::array::UInt64Array - fn row_counts(&self, column: &Column) -> Option<ArrayRef>; + fn row_counts(&self) -> Option<ArrayRef>; /// Returns [`BooleanArray`] where each row represents information known /// about specific literal `values` in a column. @@ -121,6 +123,7 @@ pub trait PruningStatistics { /// container, return `None` (the default). /// /// Note: the returned array must contain [`Self::num_containers`] rows + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn contained( &self, column: &Column, @@ -135,6 +138,10 @@ pub trait PruningStatistics { /// This feeds into [`CompositePruningStatistics`] to allow pruning /// with filters that depend both on partition columns and data columns /// (e.g. `WHERE partition_col = data_col`). +#[deprecated( + since = "52.0.0", + note = "This struct is no longer used internally. Use `replace_columns_with_literals` from `datafusion-physical-expr-adapter` to substitute partition column values before pruning. It will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." +)] #[derive(Clone)] pub struct PartitionPruningStatistics { /// Values for each column for each container. @@ -156,6 +163,7 @@ pub struct PartitionPruningStatistics { partition_schema: SchemaRef, } +#[expect(deprecated)] impl PartitionPruningStatistics { /// Create a new instance of [`PartitionPruningStatistics`].
/// @@ -169,6 +177,36 @@ impl PartitionPruningStatistics { /// This must **not** be the schema of the entire file or table: /// instead it must only be the schema of the partition columns, /// in the same order as the values in `partition_values`. + /// + /// # Example + /// + /// To create [`PartitionPruningStatistics`] for two partition columns `a` and `b`, + /// for three containers like this: + /// + /// | a | b | + /// | - | - | + /// | 1 | 2 | + /// | 3 | 4 | + /// | 5 | 6 | + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_common::ScalarValue; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::pruning::PartitionPruningStatistics; + /// + /// let partition_values = vec![ + /// vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], + /// vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], + /// vec![ScalarValue::from(5i32), ScalarValue::from(6i32)], + /// ]; + /// let partition_fields = vec![ + /// Arc::new(Field::new("a", DataType::Int32, false)), + /// Arc::new(Field::new("b", DataType::Int32, false)), + /// ]; + /// let partition_stats = + /// PartitionPruningStatistics::try_new(partition_values, partition_fields).unwrap(); + /// ``` pub fn try_new( partition_values: Vec<Vec<ScalarValue>>, partition_fields: Vec<FieldRef>, @@ -202,6 +240,7 @@ impl PartitionPruningStatistics { } } +#[expect(deprecated)] impl PruningStatistics for PartitionPruningStatistics { fn min_values(&self, column: &Column) -> Option<ArrayRef> { let index = self.partition_schema.index_of(column.name()).ok()?; @@ -228,7 +267,7 @@ impl PruningStatistics for PartitionPruningStatistics { None } - fn row_counts(&self, _column: &Column) -> Option<ArrayRef> { + fn row_counts(&self) -> Option<ArrayRef> { None } @@ -245,7 +284,7 @@ impl PruningStatistics for PartitionPruningStatistics { match acc { None => Some(Some(eq_result)), Some(acc_array) => { - arrow::compute::kernels::boolean::and(&acc_array, &eq_result) + arrow::compute::kernels::boolean::or_kleene(&acc_array, &eq_result) .map(Some) .ok() } @@ -361,11 +400,7 @@ impl PruningStatistics for PrunableStatistics { } } - fn row_counts(&self, column: &Column) -> Option<ArrayRef> { - // If the column does not exist in the schema, return None - if self.schema.index_of(column.name()).is_err() { - return None; - } + fn row_counts(&self) -> Option<ArrayRef> { if self .statistics .iter() @@ -409,10 +444,15 @@ impl PruningStatistics for PrunableStatistics { /// the first one is returned without any regard for completeness or accuracy. /// That is: if the first statistics has information for a column, even if it is incomplete, /// that is returned even if a later statistics has more complete information. +#[deprecated( + since = "52.0.0", + note = "This struct is no longer used internally. It may be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first. Please open an issue if you have a use case for it." +)] pub struct CompositePruningStatistics { pub statistics: Vec<Box<dyn PruningStatistics>>, } +#[expect(deprecated)] impl CompositePruningStatistics { /// Create a new instance of [`CompositePruningStatistics`] from /// a vector of [`PruningStatistics`].
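+    ///
+    /// A minimal sketch of construction; `partition_stats` and `file_stats`
+    /// are placeholders for any [`PruningStatistics`] implementations (the
+    /// API itself is deprecated, shown only for reference):
+    ///
+    /// ```ignore
+    /// let composite = CompositePruningStatistics::new(vec![
+    ///     Box::new(partition_stats) as Box<dyn PruningStatistics>,
+    ///     Box::new(file_stats),
+    /// ]);
+    /// ```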
@@ -427,6 +467,7 @@ impl CompositePruningStatistics { } } +#[expect(deprecated)] impl PruningStatistics for CompositePruningStatistics { fn min_values(&self, column: &Column) -> Option<ArrayRef> { for stats in &self.statistics { @@ -459,9 +500,9 @@ impl PruningStatistics for CompositePruningStatistics { None } - fn row_counts(&self, column: &Column) -> Option<ArrayRef> { + fn row_counts(&self) -> Option<ArrayRef> { for stats in &self.statistics { - if let Some(array) = stats.row_counts(column) { + if let Some(array) = stats.row_counts() { return Some(array); } } @@ -483,18 +524,26 @@ impl PruningStatistics for CompositePruningStatistics { } #[cfg(test)] +#[expect(deprecated)] +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key mod tests { use crate::{ - cast::{as_int32_array, as_uint64_array}, ColumnStatistics, + cast::{as_int32_array, as_uint64_array}, }; use super::*; use arrow::datatypes::{DataType, Field}; use std::sync::Arc; - #[test] - fn test_partition_pruning_statistics() { + /// return a PartitionPruningStatistics for two columns 'a' and 'b' + /// and the following stats + /// + /// | a | b | + /// | - | - | + /// | 1 | 2 | + /// | 3 | 4 | fn partition_pruning_statistics_setup() -> PartitionPruningStatistics { let partition_values = vec![ vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], @@ -503,18 +552,21 @@ mod tests { Arc::new(Field::new("a", DataType::Int32, false)), Arc::new(Field::new("b", DataType::Int32, false)), ]; - let partition_stats = - PartitionPruningStatistics::try_new(partition_values, partition_fields) - .unwrap(); + PartitionPruningStatistics::try_new(partition_values, partition_fields).unwrap() + } + + #[test] + fn test_partition_pruning_statistics() { + let partition_stats = partition_pruning_statistics_setup(); let column_a = Column::new_unqualified("a"); let column_b = Column::new_unqualified("b"); // Partition values don't know anything about nulls or row counts assert!(partition_stats.null_counts(&column_a).is_none()); - assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.row_counts().is_none()); assert!(partition_stats.null_counts(&column_b).is_none()); - assert!(partition_stats.row_counts(&column_b).is_none()); + assert!(partition_stats.row_counts().is_none()); // Min/max values are the same as the partition values let min_values_a = @@ -560,6 +612,85 @@ mod tests { assert_eq!(partition_stats.num_containers(), 2); } + #[test] + fn test_partition_pruning_statistics_multiple_positive_values() { + let partition_stats = partition_pruning_statistics_setup(); + + let column_a = Column::new_unqualified("a"); + + // The two containers have `a` values 1 and 3, so they both only contain values from 1 and 3 + let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(3i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, true]); + assert_eq!(contained_a, expected_contained_a); + } + + #[test] + fn test_partition_pruning_statistics_multiple_negative_values() { + let partition_stats = partition_pruning_statistics_setup(); + + let column_a = Column::new_unqualified("a"); + + // The two containers have `a` values 1 and 3, + // so the first contains ONLY values from 1,2 + // but the second does not + let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(2i32)]); + let contained_a =
+            partition_stats.contained(&column_a, &values).unwrap();
+        let expected_contained_a = BooleanArray::from(vec![true, false]);
+        assert_eq!(contained_a, expected_contained_a);
+    }
+
+    #[test]
+    fn test_partition_pruning_statistics_null_in_values() {
+        let partition_values = vec![
+            vec![
+                ScalarValue::from(1i32),
+                ScalarValue::from(2i32),
+                ScalarValue::from(3i32),
+            ],
+            vec![
+                ScalarValue::from(4i32),
+                ScalarValue::from(5i32),
+                ScalarValue::from(6i32),
+            ],
+        ];
+        let partition_fields = vec![
+            Arc::new(Field::new("a", DataType::Int32, false)),
+            Arc::new(Field::new("b", DataType::Int32, false)),
+            Arc::new(Field::new("c", DataType::Int32, false)),
+        ];
+        let partition_stats =
+            PartitionPruningStatistics::try_new(partition_values, partition_fields)
+                .unwrap();
+
+        let column_a = Column::new_unqualified("a");
+        let column_b = Column::new_unqualified("b");
+        let column_c = Column::new_unqualified("c");
+
+        let values_a = HashSet::from([ScalarValue::from(1i32), ScalarValue::Int32(None)]);
+        let contained_a = partition_stats.contained(&column_a, &values_a).unwrap();
+        let mut builder = BooleanArray::builder(2);
+        builder.append_value(true);
+        builder.append_null();
+        let expected_contained_a = builder.finish();
+        assert_eq!(contained_a, expected_contained_a);
+
+        // The first match produces a NULL boolean entry; the accumulator
+        // should update it to true when the second value matches
+        let values_b = HashSet::from([ScalarValue::Int32(None), ScalarValue::from(5i32)]);
+        let contained_b = partition_stats.contained(&column_b, &values_b).unwrap();
+        let mut builder = BooleanArray::builder(2);
+        builder.append_null();
+        builder.append_value(true);
+        let expected_contained_b = builder.finish();
+        assert_eq!(contained_b, expected_contained_b);
+
+        // All matches are null, so `contained` returns None
+        let values_c = HashSet::from([ScalarValue::Int32(None)]);
+        let contained_c = partition_stats.contained(&column_c, &values_c);
+        assert!(contained_c.is_none());
+    }
+
     #[test]
     fn test_partition_pruning_statistics_empty() {
         let partition_values = vec![];
@@ -576,9 +707,9 @@ mod tests {
 
         // Partition values don't know anything about nulls or row counts
         assert!(partition_stats.null_counts(&column_a).is_none());
-        assert!(partition_stats.row_counts(&column_a).is_none());
+        assert!(partition_stats.row_counts().is_none());
         assert!(partition_stats.null_counts(&column_b).is_none());
-        assert!(partition_stats.row_counts(&column_b).is_none());
+        assert!(partition_stats.row_counts().is_none());
 
         // Min/max values are all missing
         assert!(partition_stats.min_values(&column_a).is_none());
@@ -681,13 +812,13 @@ mod tests {
         assert_eq!(null_counts_b, expected_null_counts_b);
 
         // Row counts are the same as the statistics
-        let row_counts_a = as_uint64_array(&pruning_stats.row_counts(&column_a).unwrap())
+        let row_counts_a = as_uint64_array(&pruning_stats.row_counts().unwrap())
             .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
         let expected_row_counts_a = vec![Some(100), Some(200)];
         assert_eq!(row_counts_a, expected_row_counts_a);
-        let row_counts_b = as_uint64_array(&pruning_stats.row_counts(&column_b).unwrap())
+        let row_counts_b = as_uint64_array(&pruning_stats.row_counts().unwrap())
             .unwrap()
             .into_iter()
             .collect::<Vec<_>>();
@@ -712,7 +843,7 @@ mod tests {
-        // This is debatable, personally I think `row_count` should not take a `Column` as an argument
-        // at all since all columns should have the same number of rows.
-        // But for now we just document the current behavior in this test.
+        // `row_counts` intentionally no longer takes a `Column` argument: all
+        // columns share the same number of rows, so it reports container-level
+        // counts regardless of which column the caller is interested in.
-        let row_counts_c = as_uint64_array(&pruning_stats.row_counts(&column_c).unwrap())
+        let row_counts_c = as_uint64_array(&pruning_stats.row_counts().unwrap())
             .unwrap()
             .into_iter()
             .collect::<Vec<_>>();
@@ -720,12 +851,13 @@ mod tests {
         assert_eq!(row_counts_c, expected_row_counts_c);
         assert!(pruning_stats.contained(&column_c, &values).is_none());
 
-        // Test with a column that doesn't exist
+        // Test with a column that doesn't exist: column-specific stats
+        // return None, but row_counts is container-level and still available
         let column_d = Column::new_unqualified("d");
         assert!(pruning_stats.min_values(&column_d).is_none());
         assert!(pruning_stats.max_values(&column_d).is_none());
         assert!(pruning_stats.null_counts(&column_d).is_none());
-        assert!(pruning_stats.row_counts(&column_d).is_none());
+        assert!(pruning_stats.row_counts().is_some());
         assert!(pruning_stats.contained(&column_d, &values).is_none());
     }
 
@@ -753,8 +885,8 @@ mod tests {
         assert!(pruning_stats.null_counts(&column_b).is_none());
 
         // Row counts are all missing
-        assert!(pruning_stats.row_counts(&column_a).is_none());
-        assert!(pruning_stats.row_counts(&column_b).is_none());
+        assert!(pruning_stats.row_counts().is_none());
 
         // Contained values are all empty
         let values = HashSet::from([ScalarValue::from(1i32)]);
@@ -894,13 +1026,11 @@ mod tests {
         let expected_null_counts_col_x = vec![Some(0), Some(10)];
         assert_eq!(null_counts_col_x, expected_null_counts_col_x);
 
-        // Test row counts - only available from file statistics
-        assert!(composite_stats.row_counts(&part_a).is_none());
-        let row_counts_col_x =
-            as_uint64_array(&composite_stats.row_counts(&col_x).unwrap())
-                .unwrap()
-                .into_iter()
-                .collect::<Vec<_>>();
+        // Test row counts: container-level, available from file statistics
+        let row_counts_col_x = as_uint64_array(&composite_stats.row_counts().unwrap())
+            .unwrap()
+            .into_iter()
+            .collect::<Vec<_>>();
         let expected_row_counts = vec![Some(100), Some(200)];
         assert_eq!(row_counts_col_x, expected_row_counts);
 
@@ -913,12 +1043,13 @@ mod tests {
         // File statistics don't implement contained
         assert!(composite_stats.contained(&col_x, &values).is_none());
 
-        // Non-existent column should return None for everything
+        // Non-existent column should return None for column-specific stats,
+        // but row_counts is container-level and still available
         let non_existent = Column::new_unqualified("non_existent");
         assert!(composite_stats.min_values(&non_existent).is_none());
         assert!(composite_stats.max_values(&non_existent).is_none());
         assert!(composite_stats.null_counts(&non_existent).is_none());
-        assert!(composite_stats.row_counts(&non_existent).is_none());
+        assert!(composite_stats.row_counts().is_some());
         assert!(composite_stats.contained(&non_existent, &values).is_none());
 
         // Verify num_containers matches
@@ -1022,7 +1153,7 @@ mod tests {
         let expected_null_counts = vec![Some(0), Some(5)];
         assert_eq!(null_counts, expected_null_counts);
 
-        let row_counts = as_uint64_array(&composite_stats.row_counts(&col_a).unwrap())
+        let row_counts = as_uint64_array(&composite_stats.row_counts().unwrap())
             .unwrap()
             .into_iter()
             .collect::<Vec<_>>();
@@ -1062,11 +1193,10 @@ mod tests {
         let expected_null_counts = vec![Some(10), Some(20)];
         assert_eq!(null_counts, expected_null_counts);
 
-        let row_counts =
-            as_uint64_array(&composite_stats_reversed.row_counts(&col_a).unwrap())
-                .unwrap()
-                .into_iter()
-                .collect::<Vec<_>>();
+        let row_counts = as_uint64_array(&composite_stats_reversed.row_counts().unwrap())
+            .unwrap()
+            .into_iter()
.collect::>(); let expected_row_counts = vec![Some(1000), Some(2000)]; assert_eq!(row_counts, expected_row_counts); } diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs deleted file mode 100644 index 18c6739735ff7..0000000000000 --- a/datafusion/common/src/pyarrow.rs +++ /dev/null @@ -1,169 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Conversions between PyArrow and DataFusion types - -use arrow::array::{Array, ArrayData}; -use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use pyo3::exceptions::PyException; -use pyo3::prelude::PyErr; -use pyo3::types::{PyAnyMethods, PyList}; -use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python}; - -use crate::{DataFusionError, ScalarValue}; - -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} - -impl FromPyArrow for ScalarValue { - fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { - let py = value.py(); - let typ = value.getattr("type")?; - let val = value.call_method0("as_py")?; - - // construct pyarrow array from the python value and pyarrow type - let factory = py.import("pyarrow")?.getattr("array")?; - let args = PyList::new(py, [val])?; - let array = factory.call1((args, typ))?; - - // convert the pyarrow array to rust array using C data interface - let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?); - let scalar = ScalarValue::try_from_array(&array, 0)?; - - Ok(scalar) - } -} - -impl ToPyArrow for ScalarValue { - fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult> { - let array = self.to_array()?; - // convert to pyarrow array using C data interface - let pyarray = array.to_data().to_pyarrow(py)?; - let pyscalar = pyarray.call_method1("__getitem__", (0,))?; - - Ok(pyscalar) - } -} - -impl<'source> FromPyObject<'source> for ScalarValue { - fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult { - Self::from_pyarrow_bound(value) - } -} - -impl<'source> IntoPyObject<'source> for ScalarValue { - type Target = PyAny; - - type Output = Bound<'source, Self::Target>; - - type Error = PyErr; - - fn into_pyobject(self, py: Python<'source>) -> Result { - let array = self.to_array()?; - // convert to pyarrow array using C data interface - let pyarray = array.to_data().to_pyarrow(py)?; - pyarray.call_method1("__getitem__", (0,)) - } -} - -#[cfg(test)] -mod tests { - use pyo3::ffi::c_str; - use pyo3::py_run; - use pyo3::types::PyDict; - use pyo3::Python; - - use super::*; - - fn init_python() { - Python::initialize(); - Python::attach(|py| { - if py.run(c_str!("import pyarrow"), None, None).is_err() { - let locals = PyDict::new(py); - py.run( - c_str!( - "import sys; executable = sys.executable; python_path = sys.path" - ), - None, - 
Some(&locals), - ) - .expect("Couldn't get python info"); - let executable = locals.get_item("executable").unwrap(); - let executable: String = executable.extract().unwrap(); - - let python_path = locals.get_item("python_path").unwrap(); - let python_path: Vec = python_path.extract().unwrap(); - - panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\ - HINT: try `pip install pyarrow`\n\ - NOTE: On Mac OS, you must compile against a Framework Python \ - (default in python.org installers and brew, but not pyenv)\n\ - NOTE: On Mac OS, PYO3 might point to incorrect Python library \ - path when using virtual environments. Try \ - `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n") - } - }) - } - - #[test] - fn test_roundtrip() { - init_python(); - - let example_scalars = [ - ScalarValue::Boolean(Some(true)), - ScalarValue::Int32(Some(23)), - ScalarValue::Float64(Some(12.34)), - ScalarValue::from("Hello!"), - ScalarValue::Date32(Some(1234)), - ]; - - Python::attach(|py| { - for scalar in example_scalars.iter() { - let result = - ScalarValue::from_pyarrow_bound(&scalar.to_pyarrow(py).unwrap()) - .unwrap(); - assert_eq!(scalar, &result); - } - }); - } - - #[test] - fn test_py_scalar() -> PyResult<()> { - init_python(); - - Python::attach(|py| -> PyResult<()> { - let scalar_float = ScalarValue::Float64(Some(12.34)); - let py_float = scalar_float - .into_pyobject(py)? - .call_method0("as_py") - .unwrap(); - py_run!(py, py_float, "assert py_float == 12.34"); - - let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); - let py_string = scalar_string - .into_pyobject(py)? - .call_method0("as_py") - .unwrap(); - py_run!(py, py_string, "assert py_string == 'Hello!'"); - - Ok(()) - }) - } -} diff --git a/datafusion/common/src/rounding.rs b/datafusion/common/src/rounding.rs index 95eefd3235b5f..1796143d7cf1a 100644 --- a/datafusion/common/src/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -47,7 +47,7 @@ extern crate libc; any(target_arch = "x86_64", target_arch = "aarch64"), not(target_os = "windows") ))] -extern "C" { +unsafe extern "C" { fn fesetround(round: i32); fn fegetround() -> i32; } diff --git a/datafusion/common/src/scalar/cache.rs b/datafusion/common/src/scalar/cache.rs index f1476a518774b..5b1ad4e4ede01 100644 --- a/datafusion/common/src/scalar/cache.rs +++ b/datafusion/common/src/scalar/cache.rs @@ -20,10 +20,10 @@ use std::iter::repeat_n; use std::sync::{Arc, LazyLock, Mutex}; -use arrow::array::{new_null_array, Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, ArrayRef, PrimitiveArray, new_null_array}; use arrow::datatypes::{ - ArrowDictionaryKeyType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowDictionaryKeyType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; /// Maximum number of rows to cache to be conservative on memory usage diff --git a/datafusion/common/src/scalar/consts.rs b/datafusion/common/src/scalar/consts.rs index 8cb446b1c9211..599c2523cd2c7 100644 --- a/datafusion/common/src/scalar/consts.rs +++ b/datafusion/common/src/scalar/consts.rs @@ -17,24 +17,36 @@ // Constants defined for scalar construction. 
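The `half::f16` bounds added just below are spelled as raw bit patterns rather than computed expressions. A quick sanity check that `0x4249` really is the next f16 value above π (a standalone sketch using only the `half` crate, which this code already depends on; the bit arithmetic is for illustration only):

```rust
use half::f16;

fn main() {
    // π rounds to 3.140625 in f16, whose bit pattern is 0x4248.
    let pi = f16::from_f32(std::f32::consts::PI);
    assert_eq!(pi.to_bits(), 0x4248);

    // Incrementing the bits of a finite positive f16 yields the next
    // representable value above it: 0x4249 == 3.142578125 > π.
    let pi_upper = f16::from_bits(pi.to_bits() + 1);
    assert_eq!(pi_upper.to_bits(), 0x4249);
    assert!(f32::from(pi_upper) > std::f32::consts::PI);

    // The negative bounds only differ by the sign bit: -π is 0xC248, and
    // the next value *below* it is 0xC249.
    assert_eq!((-pi).to_bits(), 0xC248);
}
```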
+// Next f16 value above π (upper bound)
+pub(super) const PI_UPPER_F16: half::f16 = half::f16::from_bits(0x4249);
+
 // Next f32 value above π (upper bound)
 pub(super) const PI_UPPER_F32: f32 = std::f32::consts::PI.next_up();
 
 // Next f64 value above π (upper bound)
 pub(super) const PI_UPPER_F64: f64 = std::f64::consts::PI.next_up();
 
+// Next f16 value below -π (lower bound)
+pub(super) const NEGATIVE_PI_LOWER_F16: half::f16 = half::f16::from_bits(0xC249);
+
 // Next f32 value below -π (lower bound)
 pub(super) const NEGATIVE_PI_LOWER_F32: f32 = (-std::f32::consts::PI).next_down();
 
 // Next f64 value below -π (lower bound)
 pub(super) const NEGATIVE_PI_LOWER_F64: f64 = (-std::f64::consts::PI).next_down();
 
+// Next f16 value above π/2 (upper bound)
+pub(super) const FRAC_PI_2_UPPER_F16: half::f16 = half::f16::from_bits(0x3E49);
+
 // Next f32 value above π/2 (upper bound)
 pub(super) const FRAC_PI_2_UPPER_F32: f32 = std::f32::consts::FRAC_PI_2.next_up();
 
 // Next f64 value above π/2 (upper bound)
 pub(super) const FRAC_PI_2_UPPER_F64: f64 = std::f64::consts::FRAC_PI_2.next_up();
 
+// Next f16 value below -π/2 (lower bound)
+pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F16: half::f16 = half::f16::from_bits(0xBE49);
+
 // Next f32 value below -π/2 (lower bound)
 pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 =
     (-std::f32::consts::FRAC_PI_2).next_down();
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 188a169a3dd2f..d726b5c94016f 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -26,6 +26,7 @@ use std::cmp::Ordering;
 use std::collections::{HashSet, VecDeque};
 use std::convert::Infallible;
 use std::fmt;
+use std::fmt::Write;
 use std::hash::Hash;
 use std::hash::Hasher;
 use std::iter::repeat_n;
@@ -33,64 +34,163 @@ use std::mem::{size_of, size_of_val};
 use std::str::FromStr;
 use std::sync::Arc;
 
+use crate::assert_or_internal_err;
 use crate::cast::{
     as_binary_array, as_binary_view_array, as_boolean_array, as_date32_array,
-    as_date64_array, as_decimal128_array, as_decimal256_array, as_decimal32_array,
-    as_decimal64_array, as_dictionary_array, as_duration_microsecond_array,
+    as_date64_array, as_decimal32_array, as_decimal64_array, as_decimal128_array,
+    as_decimal256_array, as_dictionary_array, as_duration_microsecond_array,
     as_duration_millisecond_array, as_duration_nanosecond_array,
     as_duration_second_array, as_fixed_size_binary_array, as_fixed_size_list_array,
-    as_float16_array, as_float32_array, as_float64_array, as_int16_array, as_int32_array,
-    as_int64_array, as_int8_array, as_interval_dt_array, as_interval_mdn_array,
+    as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array,
+    as_int32_array, as_int64_array, as_interval_dt_array, as_interval_mdn_array,
     as_interval_ym_array, as_large_binary_array, as_large_list_array,
-    as_large_string_array, as_string_array, as_string_view_array,
-    as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array,
-    as_time64_nanosecond_array, as_timestamp_microsecond_array,
-    as_timestamp_millisecond_array, as_timestamp_nanosecond_array,
-    as_timestamp_second_array, as_uint16_array, as_uint32_array, as_uint64_array,
-    as_uint8_array, as_union_array,
+    as_large_list_view_array, as_large_string_array, as_list_view_array, as_run_array,
+    as_string_array, as_string_view_array, as_time32_millisecond_array,
+    as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array,
+    as_timestamp_microsecond_array,
as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, as_uint8_array, + as_uint16_array, as_uint32_array, as_uint64_array, as_union_array, }; -use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err}; +use crate::error::{_exec_err, _internal_err, _not_impl_err, DataFusionError, Result}; use crate::format::DEFAULT_CAST_OPTIONS; use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - new_empty_array, new_null_array, Array, ArrayData, ArrayRef, ArrowNativeTypeOp, - ArrowPrimitiveType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + Array, ArrayData, ArrayDataBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + AsArray, BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, + Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, - Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, - LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, - PrimitiveArray, Scalar, StringArray, StringViewArray, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, UnionArray, + GenericListViewArray, Int8Array, Int16Array, Int32Array, Int64Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeListArray, LargeListViewArray, LargeStringArray, ListArray, + ListViewArray, MapArray, MutableArrayData, PrimitiveArray, RunArray, Scalar, + StringArray, StringViewArray, StringViewBuilder, StructArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, UnionArray, + downcast_run_array, new_empty_array, new_null_array, }; use arrow::buffer::{BooleanBuffer, ScalarBuffer}; -use arrow::compute::kernels::cast::{cast_with_options, CastOptions}; +use arrow::compute::kernels::cast::{CastOptions, cast_with_options}; use arrow::compute::kernels::numeric::{ add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping, }; use arrow::datatypes::{ - i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType, - ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, - Decimal32Type, Decimal64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionFields, - UnionMode, DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type, + 
+    Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field,
+    FieldRef, Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime,
+    IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit,
+    IntervalYearMonthType, RunEndIndexType, TimeUnit, TimestampMicrosecondType,
+    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
+    UInt16Type, UInt32Type, UInt64Type, UnionFields, UnionMode, i256,
+    validate_decimal_precision_and_scale,
 };
-use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions};
+use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string};
 use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
 use chrono::{Duration, NaiveDate};
 use half::f16;
 pub use struct_builder::ScalarStructBuilder;
 
+const SECONDS_PER_DAY: i64 = 86_400;
+const MILLIS_PER_DAY: i64 = SECONDS_PER_DAY * 1_000;
+const MICROS_PER_DAY: i64 = MILLIS_PER_DAY * 1_000;
+const NANOS_PER_DAY: i64 = MICROS_PER_DAY * 1_000;
+const MICROS_PER_MILLISECOND: i64 = 1_000;
+const NANOS_PER_MILLISECOND: i64 = 1_000_000;
+
+/// Returns the multiplier that converts the input date representation into the
+/// desired timestamp unit, if the conversion requires a multiplication that can
+/// overflow an `i64`.
+pub fn date_to_timestamp_multiplier(
+    source_type: &DataType,
+    target_type: &DataType,
+) -> Option<i64> {
+    let DataType::Timestamp(target_unit, _) = target_type else {
+        return None;
+    };
+
+    // Only `Timestamp` target types have a time unit; otherwise no
+    // multiplier applies (handled above). The function returns `Some(m)`
+    // when converting the `source_type` to `target_type` requires a
+    // multiplication that could overflow `i64`. It returns `None` when
+    // the conversion is a division or otherwise doesn't require a
+    // multiplication (e.g. Date64 -> Second).
+    match source_type {
+        // Date32 stores days since epoch. Converting to any timestamp
+        // unit requires multiplying by the per-day factor (seconds,
+        // milliseconds, microseconds, nanoseconds).
+        DataType::Date32 => Some(match target_unit {
+            TimeUnit::Second => SECONDS_PER_DAY,
+            TimeUnit::Millisecond => MILLIS_PER_DAY,
+            TimeUnit::Microsecond => MICROS_PER_DAY,
+            TimeUnit::Nanosecond => NANOS_PER_DAY,
+        }),
+
+        // Date64 stores milliseconds since epoch. Converting to
+        // seconds is a division (no multiplication), so return `None`.
+        // Converting to milliseconds is an identity conversion (also `None`).
+        // Converting to micro/nano requires multiplying by 1_000 / 1_000_000.
+        DataType::Date64 => match target_unit {
+            TimeUnit::Second => None,
+            // Converting Date64 (ms since epoch) to millisecond timestamps
+            // is an identity conversion and does not require multiplication.
+            // Returning `None` indicates no multiplication-based overflow
+            // check is necessary.
+            TimeUnit::Millisecond => None,
+            TimeUnit::Microsecond => Some(MICROS_PER_MILLISECOND),
+            TimeUnit::Nanosecond => Some(NANOS_PER_MILLISECOND),
+        },
+
+        _ => None,
+    }
+}
+
+/// Ensures the provided value can be represented as a timestamp with the given
+/// multiplier. Returns a [`DataFusionError::Execution`] when the converted
+/// value would overflow the timestamp range.
+pub fn ensure_timestamp_in_bounds( + value: i64, + multiplier: i64, + source_type: &DataType, + target_type: &DataType, +) -> Result<()> { + if multiplier <= 1 { + return Ok(()); + } + + if value.checked_mul(multiplier).is_none() { + let target = format_timestamp_type_for_error(target_type); + _exec_err!( + "Cannot cast {} value {} to {}: converted value exceeds the representable i64 range", + source_type, + value, + target + ) + } else { + Ok(()) + } +} + +/// Format a `DataType::Timestamp` into a short, stable string used in +/// user-facing error messages. +pub(crate) fn format_timestamp_type_for_error(target_type: &DataType) -> String { + match target_type { + DataType::Timestamp(unit, _) => { + let s = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "ms", + TimeUnit::Microsecond => "us", + TimeUnit::Nanosecond => "ns", + }; + format!("Timestamp({s})") + } + other => format!("{other}"), + } +} + /// A dynamically typed, nullable single value. /// /// While an arrow [`Array`]) stores one or more values of the same type, in a @@ -158,8 +258,8 @@ pub use struct_builder::ScalarStructBuilder; /// /// # Nested Types /// -/// `List` / `LargeList` / `FixedSizeList` / `Struct` / `Map` are represented as a -/// single element array of the corresponding type. +/// `List` / `LargeList` / `FixedSizeList` / `ListView` / `LargeListView` / `Struct` / `Map` +/// are represented as a single element array of the corresponding type. /// /// ## Example: Creating [`ScalarValue::Struct`] using [`ScalarStructBuilder`] /// ``` @@ -282,6 +382,14 @@ pub enum ScalarValue { List(Arc), /// The array must be a LargeListArray with length 1. LargeList(Arc), + /// Represents a single element of a [`ListViewArray`] as an [`ArrayRef`] + /// + /// The array must be a ListViewArray with length 1. + ListView(Arc), + /// Represents a single element of a [`LargeListViewArray`] as an [`ArrayRef`] + /// + /// The array must be a LargeListViewArray with length 1. + LargeListView(Arc), /// Represents a single element [`StructArray`] as an [`ArrayRef`]. See /// [`ScalarValue`] for examples of how to create instances of this type. 
    Struct(Arc<StructArray>),
@@ -331,6 +439,8 @@ pub enum ScalarValue {
     Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
     /// Dictionary type: index type and value
     Dictionary(Box<DataType>, Box<ScalarValue>),
+    /// (run-ends field, value field, value)
+    RunEndEncoded(FieldRef, FieldRef, Box<ScalarValue>),
 }
 
 impl Hash for Fl {
@@ -416,6 +526,10 @@ impl PartialEq for ScalarValue {
             (List(_), _) => false,
             (LargeList(v1), LargeList(v2)) => v1.eq(v2),
             (LargeList(_), _) => false,
+            (ListView(v1), ListView(v2)) => v1.eq(v2),
+            (ListView(_), _) => false,
+            (LargeListView(v1), LargeListView(v2)) => v1.eq(v2),
+            (LargeListView(_), _) => false,
             (Struct(v1), Struct(v2)) => v1.eq(v2),
             (Struct(_), _) => false,
             (Map(v1), Map(v2)) => v1.eq(v2),
@@ -460,6 +574,10 @@ impl PartialEq for ScalarValue {
             (Union(_, _, _), _) => false,
             (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
             (Dictionary(_, _), _) => false,
+            (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => {
+                rf1.eq(rf2) && vf1.eq(vf2) && v1.eq(v2)
+            }
+            (RunEndEncoded(_, _, _), _) => false,
             (Null, Null) => true,
             (Null, _) => false,
         }
@@ -557,7 +675,8 @@ impl PartialOrd for ScalarValue {
             (FixedSizeBinary(_, _), _) => None,
             (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
             (LargeBinary(_), _) => None,
-            // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1
+            // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList / ScalarValue::ListView / ScalarValue::LargeListView
+            // are guaranteed to have length 1
             (List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()),
             (FixedSizeList(arr1), FixedSizeList(arr2)) => {
                 partial_cmp_list(arr1.as_ref(), arr2.as_ref())
@@ -565,7 +684,17 @@ impl PartialOrd for ScalarValue {
             (LargeList(arr1), LargeList(arr2)) => {
                 partial_cmp_list(arr1.as_ref(), arr2.as_ref())
             }
-            (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None,
+            (ListView(arr1), ListView(arr2)) => {
+                partial_cmp_list(arr1.as_ref(), arr2.as_ref())
+            }
+            (LargeListView(arr1), LargeListView(arr2)) => {
+                partial_cmp_list(arr1.as_ref(), arr2.as_ref())
+            }
+            (List(_), _)
+            | (LargeList(_), _)
+            | (FixedSizeList(_), _)
+            | (ListView(_), _)
+            | (LargeListView(_), _) => None,
             (Struct(struct_arr1), Struct(struct_arr2)) => {
                 partial_cmp_struct(struct_arr1.as_ref(), struct_arr2.as_ref())
             }
@@ -622,20 +751,25 @@ impl PartialOrd for ScalarValue {
             (Union(_, _, _), _) => None,
             (Dictionary(k1, v1), Dictionary(k2, v2)) => {
                 // Don't compare if the key types don't match (it is effectively a different datatype)
-                if k1 == k2 {
-                    v1.partial_cmp(v2)
-                } else {
-                    None
-                }
-            }
-            (Dictionary(_, _), _) => None,
+                if k1 == k2 { v1.partial_cmp(v2) } else { None }
+            }
+            (Dictionary(_, _), _) => None,
+            (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => {
+                // Don't compare if the run-ends or value fields don't match (it is effectively a different datatype)
+                if rf1 == rf2 && vf1 == vf2 { v1.partial_cmp(v2) } else { None }
+            }
+            (RunEndEncoded(_, _, _), _) => None,
             (Null, Null) => Some(Ordering::Equal),
             (Null, _) => None,
         }
     }
 }
 
-/// List/LargeList/FixedSizeList scalars always have a single element
+/// List/LargeList/FixedSizeList/ListView/LargeListView scalars always have a single element
 /// array. This function returns that array
 fn first_array_for_list(arr: &dyn Array) -> ArrayRef {
     assert_eq!(arr.len(), 1);
@@ -645,12 +779,18 @@ fn first_array_for_list(arr: &dyn Array) -> ArrayRef {
         arr.value(0)
     } else if let Some(arr) = arr.as_fixed_size_list_opt() {
         arr.value(0)
+    } else if let Some(arr) = arr.as_list_view_opt::<i32>() {
+        arr.value(0)
+    } else if let Some(arr) = arr.as_list_view_opt::<i64>() {
+        arr.value(0)
     } else {
-        unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen")
+        unreachable!(
+            "Since only List / LargeList / FixedSizeList / ListView / LargeListView are supported, this should never happen"
+        )
     }
 }
 
-/// Compares two List/LargeList/FixedSizeList scalars
+/// Compares two List/LargeList/FixedSizeList/ListView/LargeListView scalars
 fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
     if arr1.data_type() != arr2.data_type() {
         return None;
@@ -838,6 +978,12 @@ impl Hash for ScalarValue {
             FixedSizeList(arr) => {
                 hash_nested_array(arr.to_owned() as ArrayRef, state);
             }
+            ListView(arr) => {
+                hash_nested_array(arr.to_owned() as ArrayRef, state);
+            }
+            LargeListView(arr) => {
+                hash_nested_array(arr.to_owned() as ArrayRef, state);
+            }
             Struct(arr) => {
                 hash_nested_array(arr.to_owned() as ArrayRef, state);
             }
@@ -870,6 +1016,11 @@ impl Hash for ScalarValue {
                 k.hash(state);
                 v.hash(state);
             }
+            RunEndEncoded(rf, vf, v) => {
+                rf.hash(state);
+                vf.hash(state);
+                v.hash(state);
+            }
             // stable hash for Null value
             Null => 1.hash(state),
         }
@@ -878,10 +1029,10 @@ impl Hash for ScalarValue {
 
 fn hash_nested_array<H: Hasher>(arr: ArrayRef, state: &mut H) {
     let len = arr.len();
-    let arrays = vec![arr];
     let hashes_buffer = &mut vec![0; len];
-    let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
-    let hashes = create_hashes(&arrays, &random_state, hashes_buffer).unwrap();
+    let random_state = crate::hash_utils::RandomState::with_seed(0);
+    let hashes = create_hashes(&[arr], &random_state, hashes_buffer)
+        .expect("hash_nested_array: failed to create row hashes");
     // Hash back to std::hash::Hasher
     hashes.hash(state);
 }
@@ -1055,13 +1206,8 @@ impl ScalarValue {
 
     /// Create a decimal Scalar from value/precision and scale.
     pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> Result<Self> {
-        // make sure the precision and scale is valid
-        if precision <= DECIMAL128_MAX_PRECISION && scale.unsigned_abs() <= precision {
-            return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
-        }
-        _internal_err!(
-            "Can not new a decimal type ScalarValue for precision {precision} and scale {scale}"
-        )
+        Self::validate_decimal_or_internal_err::<Decimal128Type>(precision, scale)?;
+        Ok(ScalarValue::Decimal128(Some(value), precision, scale))
     }
 
     /// Create a Null instance of ScalarValue for this datatype
@@ -1153,7 +1299,14 @@ impl ScalarValue {
                 index_type.clone(),
                 Box::new(value_type.as_ref().try_into()?),
             ),
+            DataType::RunEndEncoded(run_ends_field, value_field) => {
+                ScalarValue::RunEndEncoded(
+                    Arc::clone(run_ends_field),
+                    Arc::clone(value_field),
+                    Box::new(value_field.data_type().try_into()?),
+                )
+            }
-            // `ScalaValue::List` contains single element `ListArray`.
+            // `ScalarValue::List` contains single element `ListArray`.
DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), )), @@ -1161,7 +1314,7 @@ impl ScalarValue { DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), )), - // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. + // `ScalarValue::FixedSizeList` contains single element `FixedSizeList`. DataType::FixedSizeList(field_ref, fixed_length) => { ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( Arc::clone(field_ref), @@ -1169,6 +1322,12 @@ impl ScalarValue { 1, ))) } + DataType::ListView(field_ref) => ScalarValue::ListView(Arc::new( + GenericListViewArray::new_null(Arc::clone(field_ref), 1), + )), + DataType::LargeListView(field_ref) => ScalarValue::LargeListView(Arc::new( + GenericListViewArray::new_null(Arc::clone(field_ref), 1), + )), DataType::Struct(fields) => ScalarValue::Struct( new_null_array(&DataType::Struct(fields.to_owned()), 1) .as_struct() @@ -1241,6 +1400,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing PI pub fn new_pi(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::PI)), DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)), _ => _internal_err!("PI is not supported for data type: {}", datatype), @@ -1250,6 +1410,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing PI's upper bound pub fn new_pi_upper(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::Float16(Some(consts::PI_UPPER_F16))), DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)), _ => { @@ -1261,6 +1422,9 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI's lower bound pub fn new_negative_pi_lower(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => { + Ok(ScalarValue::Float16(Some(consts::NEGATIVE_PI_LOWER_F16))) + } DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)), _ => { @@ -1272,6 +1436,9 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => { + Ok(ScalarValue::Float16(Some(consts::FRAC_PI_2_UPPER_F16))) + } DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)), _ => { @@ -1283,6 +1450,9 @@ impl ScalarValue { // Returns a [`ScalarValue`] representing FRAC_PI_2's lower bound pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::Float16(Some( + consts::NEGATIVE_FRAC_PI_2_LOWER_F16, + ))), DataType::Float32 => { Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32)) } @@ -1298,6 +1468,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI pub fn new_negative_pi(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(-f16::PI)), DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)), _ => _internal_err!("-PI is not supported for data type: {}", datatype), @@ -1307,6 +1478,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] 
representing PI/2 pub fn new_frac_pi_2(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::FRAC_PI_2)), DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)), _ => _internal_err!("PI/2 is not supported for data type: {}", datatype), @@ -1316,6 +1488,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI/2 pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(-f16::FRAC_PI_2)), DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)), _ => _internal_err!("-PI/2 is not supported for data type: {}", datatype), @@ -1325,6 +1498,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing infinity pub fn new_infinity(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::INFINITY)), DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)), DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)), _ => { @@ -1336,6 +1510,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing negative infinity pub fn new_neg_infinity(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::NEG_INFINITY)), DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)), DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)), _ => { @@ -1359,7 +1534,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(0)), DataType::UInt32 => ScalarValue::UInt32(Some(0)), DataType::UInt64 => ScalarValue::UInt64(Some(0)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::ZERO)), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), DataType::Decimal32(precision, scale) => { @@ -1467,6 +1642,8 @@ impl ScalarValue { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) | DataType::Timestamp(_, _) @@ -1503,7 +1680,7 @@ impl ScalarValue { let empty_arr = new_empty_array(field.data_type()); let values = Arc::new( SingleRowListArrayBuilder::new(empty_arr) - .with_nullable(field.is_nullable()) + .with_field(field) .build_fixed_size_list_array(0), ); Ok(ScalarValue::FixedSizeList(values)) @@ -1512,6 +1689,24 @@ impl ScalarValue { let list = ScalarValue::new_large_list(&[], field.data_type()); Ok(ScalarValue::LargeList(list)) } + DataType::ListView(field) => { + let empty_arr = new_empty_array(field.data_type()); + let values = Arc::new( + SingleRowListArrayBuilder::new(empty_arr) + .with_field(field) + .build_list_view_array(), + ); + Ok(ScalarValue::ListView(values)) + } + DataType::LargeListView(field) => { + let empty_arr = new_empty_array(field.data_type()); + let values = Arc::new( + SingleRowListArrayBuilder::new(empty_arr) + .with_field(field) + .build_large_list_view_array(), + ); + Ok(ScalarValue::LargeListView(values)) + } // Struct types DataType::Struct(fields) => { @@ -1535,6 +1730,14 @@ impl ScalarValue { Box::new(ScalarValue::new_default(value_type)?), )), + DataType::RunEndEncoded(run_ends_field, value_field) => { + Ok(ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + 
Box::new(ScalarValue::new_default(value_field.data_type())?), + )) + } + // Map types DataType::Map(field, _) => Ok(ScalarValue::Map(Arc::new(MapArray::from( ArrayData::new_empty(field.data_type()), @@ -1553,13 +1756,6 @@ impl ScalarValue { _internal_err!("Union type must have at least one field") } } - - // Unsupported types for now - _ => { - _not_impl_err!( - "Default value for data_type \"{datatype}\" is not implemented yet" - ) - } } } @@ -1574,16 +1770,14 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(1)), DataType::UInt32 => ScalarValue::UInt32(Some(1)), DataType::UInt64 => ScalarValue::UInt64(Some(1)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::ONE)), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), DataType::Decimal32(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match 10_i32.checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal32(Some(value), *precision, *scale) @@ -1592,12 +1786,10 @@ impl ScalarValue { } } DataType::Decimal64(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i64::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal64(Some(value), *precision, *scale) @@ -1606,12 +1798,10 @@ impl ScalarValue { } } DataType::Decimal128(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i128::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal128(Some(value), *precision, *scale) @@ -1620,12 +1810,10 @@ impl ScalarValue { } } DataType::Decimal256(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i256::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal256(Some(value), *precision, *scale) @@ -1648,16 +1836,14 @@ impl ScalarValue { DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::NEG_ONE)), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), DataType::Decimal32(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); 
match 10_i32.checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal32(Some(-value), *precision, *scale) @@ -1666,12 +1852,10 @@ impl ScalarValue { } } DataType::Decimal64(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i64::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal64(Some(-value), *precision, *scale) @@ -1680,12 +1864,10 @@ impl ScalarValue { } } DataType::Decimal128(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i128::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal128(Some(-value), *precision, *scale) @@ -1694,12 +1876,10 @@ impl ScalarValue { } } DataType::Decimal256(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i256::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal256(Some(-value), *precision, *scale) @@ -1729,14 +1909,10 @@ impl ScalarValue { DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), DataType::Decimal32(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale <= 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match 10_i32.checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal32(Some(value), *precision, *scale) @@ -1745,14 +1921,10 @@ impl ScalarValue { } } DataType::Decimal64(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale <= 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i64::from(10).checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal64(Some(value), *precision, *scale) @@ -1761,14 +1933,10 @@ impl ScalarValue { } } DataType::Decimal128(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i128::from(10).checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal128(Some(value), *precision, *scale) @@ -1777,14 +1945,10 @@ impl ScalarValue { } } DataType::Decimal256(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + 
Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i256::from(10).checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal256(Some(value), *precision, *scale) @@ -1849,6 +2013,8 @@ impl ScalarValue { ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), + ScalarValue::ListView(arr) => arr.data_type().to_owned(), + ScalarValue::LargeListView(arr) => arr.data_type().to_owned(), ScalarValue::Struct(arr) => arr.data_type().to_owned(), ScalarValue::Map(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, @@ -1878,10 +2044,219 @@ impl ScalarValue { ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } + ScalarValue::RunEndEncoded(run_ends_field, value_field, _) => { + DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + ) + } ScalarValue::Null => DataType::Null, } } + #[inline] + fn can_use_direct_add(lhs: &ScalarValue, rhs: &ScalarValue) -> bool { + matches!( + (lhs, rhs), + (ScalarValue::Int8(_), ScalarValue::Int8(_)) + | (ScalarValue::Int16(_), ScalarValue::Int16(_)) + | (ScalarValue::Int32(_), ScalarValue::Int32(_)) + | (ScalarValue::Int64(_), ScalarValue::Int64(_)) + | (ScalarValue::UInt8(_), ScalarValue::UInt8(_)) + | (ScalarValue::UInt16(_), ScalarValue::UInt16(_)) + | (ScalarValue::UInt32(_), ScalarValue::UInt32(_)) + | (ScalarValue::UInt64(_), ScalarValue::UInt64(_)) + | (ScalarValue::Float16(_), ScalarValue::Float16(_)) + | (ScalarValue::Float32(_), ScalarValue::Float32(_)) + | (ScalarValue::Float64(_), ScalarValue::Float64(_)) + | ( + ScalarValue::Decimal32(_, _, _), + ScalarValue::Decimal32(_, _, _) + ) + | ( + ScalarValue::Decimal64(_, _, _), + ScalarValue::Decimal64(_, _, _) + ) + | ( + ScalarValue::Decimal128(_, _, _), + ScalarValue::Decimal128(_, _, _), + ) + | ( + ScalarValue::Decimal256(_, _, _), + ScalarValue::Decimal256(_, _, _), + ) + ) + } + + #[inline] + fn add_optional( + lhs: &mut Option, + rhs: Option, + checked: bool, + ) -> Result<()> { + match rhs { + Some(rhs) => { + if let Some(lhs) = lhs.as_mut() { + *lhs = if checked { + lhs.add_checked(rhs).map_err(|e| arrow_datafusion_err!(e))? + } else { + lhs.add_wrapping(rhs) + }; + } + } + None => *lhs = None, + } + Ok(()) + } + + #[inline] + fn add_decimal_values( + lhs_value: &mut Option, + lhs_precision: &mut u8, + lhs_scale: &mut i8, + rhs_value: Option, + rhs_precision: u8, + rhs_scale: i8, + ) -> Result<()> + where + T::Native: ArrowNativeTypeOp, + { + Self::validate_decimal_or_internal_err::(*lhs_precision, *lhs_scale)?; + Self::validate_decimal_or_internal_err::(rhs_precision, rhs_scale)?; + + let result_scale = (*lhs_scale).max(rhs_scale); + // Decimal scales can be negative, so use a wider signed type for the + // intermediate precision arithmetic. 
+ let lhs_precision_delta = i16::from(*lhs_precision) - i16::from(*lhs_scale); + let rhs_precision_delta = i16::from(rhs_precision) - i16::from(rhs_scale); + let result_precision = + (i16::from(result_scale) + lhs_precision_delta.max(rhs_precision_delta) + 1) + .min(i16::from(T::MAX_PRECISION)) as u8; + + Self::validate_decimal_or_internal_err::(result_precision, result_scale)?; + + let lhs_mul = T::Native::usize_as(10) + .pow_checked((result_scale - *lhs_scale) as u32) + .map_err(|e| arrow_datafusion_err!(e))?; + let rhs_mul = T::Native::usize_as(10) + .pow_checked((result_scale - rhs_scale) as u32) + .map_err(|e| arrow_datafusion_err!(e))?; + + let result_value = match (*lhs_value, rhs_value) { + (Some(lhs_value), Some(rhs_value)) => Some( + lhs_value + .mul_checked(lhs_mul) + .and_then(|lhs| { + rhs_value + .mul_checked(rhs_mul) + .and_then(|rhs| lhs.add_checked(rhs)) + }) + .map_err(|e| arrow_datafusion_err!(e))?, + ), + _ => None, + }; + + *lhs_value = result_value; + *lhs_precision = result_precision; + *lhs_scale = result_scale; + + Ok(()) + } + + #[inline] + fn try_add_in_place_impl( + &mut self, + other: &ScalarValue, + checked: bool, + ) -> Result { + match (self, other) { + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + Self::add_optional(lhs, *rhs, checked)?; + } + ( + ScalarValue::Decimal32(lhs, p, s), + ScalarValue::Decimal32(rhs, rhs_p, rhs_s), + ) => { + Self::add_decimal_values::( + lhs, p, s, *rhs, *rhs_p, *rhs_s, + )?; + } + ( + ScalarValue::Decimal64(lhs, p, s), + ScalarValue::Decimal64(rhs, rhs_p, rhs_s), + ) => { + Self::add_decimal_values::( + lhs, p, s, *rhs, *rhs_p, *rhs_s, + )?; + } + ( + ScalarValue::Decimal128(lhs, p, s), + ScalarValue::Decimal128(rhs, rhs_p, rhs_s), + ) => { + Self::add_decimal_values::( + lhs, p, s, *rhs, *rhs_p, *rhs_s, + )?; + } + ( + ScalarValue::Decimal256(lhs, p, s), + ScalarValue::Decimal256(rhs, rhs_p, rhs_s), + ) => { + Self::add_decimal_values::( + lhs, p, s, *rhs, *rhs_p, *rhs_s, + )?; + } + _ => return Ok(false), + } + + Ok(true) + } + + #[inline] + pub(crate) fn try_add_wrapping_in_place( + &mut self, + other: &ScalarValue, + ) -> Result { + self.try_add_in_place_impl(other, false) + } + + #[inline] + pub(crate) fn try_add_checked_in_place( + &mut self, + other: &ScalarValue, + ) -> Result { + self.try_add_in_place_impl(other, true) + } + /// Calculate arithmetic negation for a scalar value pub fn arithmetic_negate(&self) -> Result { fn neg_checked_with_ctx( @@ -1899,9 
+2274,7 @@ impl ScalarValue { | ScalarValue::Float16(None) | ScalarValue::Float32(None) | ScalarValue::Float64(None) => Ok(self.clone()), - ScalarValue::Float16(Some(v)) => { - Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32())))) - } + ScalarValue::Float16(Some(v)) => Ok(ScalarValue::Float16(Some(-v))), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), @@ -2019,15 +2392,34 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add>(&self, other: T) -> Result { - let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + let other = other.borrow(); + if Self::can_use_direct_add(self, other) { + let mut result = self.clone(); + if result.try_add_wrapping_in_place(other)? { + return Ok(result); + } + debug_assert!(false, "fast-path eligibility drifted from implementation"); + } + + let r = add_wrapping(&self.to_scalar()?, &other.to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } + /// Checked addition of `ScalarValue` /// /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add_checked>(&self, other: T) -> Result { - let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + let other = other.borrow(); + if Self::can_use_direct_add(self, other) { + let mut result = self.clone(); + if result.try_add_checked_in_place(other)? { + return Ok(result); + } + debug_assert!(false, "fast-path eligibility drifted from implementation"); + } + + let r = add(&self.to_scalar()?, &other.to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -2133,6 +2525,8 @@ impl ScalarValue { ScalarValue::List(arr) => arr.len() == arr.null_count(), ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), + ScalarValue::ListView(arr) => arr.len() == arr.null_count(), + ScalarValue::LargeListView(arr) => arr.len() == arr.null_count(), ScalarValue::Struct(arr) => arr.len() == arr.null_count(), ScalarValue::Map(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), @@ -2157,6 +2551,7 @@ impl ScalarValue { None => true, }, ScalarValue::Dictionary(_, v) => v.is_null(), + ScalarValue::RunEndEncoded(_, _, v) => v.is_null(), } } @@ -2187,6 +2582,25 @@ impl ScalarValue { (Self::Float64(Some(l)), Self::Float64(Some(r))) => { Some((l - r).abs().round() as _) } + (Self::Date32(Some(l)), Self::Date32(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Date64(Some(l)), Self::Date64(Some(r))) => Some(l.abs_diff(*r) as _), + // Timestamp values are stored as epoch ticks regardless of timezone + // annotation, so the distance is tz-independent (tz is display metadata). 
+ (Self::TimestampSecond(Some(l), _), Self::TimestampSecond(Some(r), _)) => { + Some(l.abs_diff(*r) as _) + } + ( + Self::TimestampMillisecond(Some(l), _), + Self::TimestampMillisecond(Some(r), _), + ) => Some(l.abs_diff(*r) as _), + ( + Self::TimestampMicrosecond(Some(l), _), + Self::TimestampMicrosecond(Some(r), _), + ) => Some(l.abs_diff(*r) as _), + ( + Self::TimestampNanosecond(Some(l), _), + Self::TimestampNanosecond(Some(r), _), + ) => Some(l.abs_diff(*r) as _), ( Self::Decimal128(Some(l), lprecision, lscale), Self::Decimal128(Some(r), rprecision, rscale), @@ -2293,18 +2707,20 @@ impl ScalarValue { macro_rules! build_array_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array) } }}; @@ -2313,18 +2729,20 @@ impl ScalarValue { macro_rules! build_array_primitive_tz { ($ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v, _) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array.with_timezone_opt($TZ.clone())) } }}; @@ -2335,18 +2753,20 @@ impl ScalarValue { macro_rules! build_array_string { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. 
\ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array) } }}; @@ -2477,6 +2897,8 @@ impl ScalarValue { } DataType::List(_) | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => { @@ -2518,6 +2940,94 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + DataType::RunEndEncoded(run_ends_field, value_field) => { + fn make_run_array( + scalars: impl IntoIterator, + run_ends_field: &FieldRef, + values_field: &FieldRef, + ) -> Result { + let mut scalars = scalars.into_iter(); + + let mut run_ends = vec![]; + let mut value_scalars = vec![]; + + let mut len = R::Native::ONE; + let mut current = + if let Some(ScalarValue::RunEndEncoded(_, _, scalar)) = + scalars.next() + { + *scalar + } else { + // We are guaranteed to have one element of correct + // type because we peeked above + unreachable!() + }; + for scalar in scalars { + let scalar = match scalar { + ScalarValue::RunEndEncoded( + inner_run_ends_field, + inner_value_field, + scalar, + ) if &inner_run_ends_field == run_ends_field + && &inner_value_field == values_field => + { + *scalar + } + _ => { + return _exec_err!( + "Expected RunEndEncoded scalar with run-ends field {run_ends_field} but got: {scalar:?}" + ); + } + }; + + // new run + if scalar != current { + run_ends.push(len); + value_scalars.push(current); + current = scalar; + } + + len = len.add_checked(R::Native::ONE).map_err(|_| { + DataFusionError::Execution(format!( + "Cannot construct RunArray: Overflows run-ends type {}", + run_ends_field.data_type() + )) + })?; + } + + run_ends.push(len); + value_scalars.push(current); + + let run_ends = PrimitiveArray::::from_iter_values(run_ends); + let values = ScalarValue::iter_to_array(value_scalars)?; + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(RunArray::logical_len(&run_ends)) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + + match run_ends_field.data_type() { + DataType::Int16 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int32 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int64 => { + make_run_array::(scalars, run_ends_field, value_field)? 
+ } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } DataType::FixedSizeBinary(size) => { let array = scalars .map(|sv| { @@ -2545,10 +3055,7 @@ impl ScalarValue { DataType::Time32(TimeUnit::Microsecond) | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) - | DataType::Time64(TimeUnit::Millisecond) - | DataType::RunEndEncoded(_, _) - | DataType::ListView(_) - | DataType::LargeListView(_) => { + | DataType::Time64(TimeUnit::Millisecond) => { return _not_impl_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, @@ -2648,71 +3155,6 @@ impl ScalarValue { Ok(array) } - fn build_decimal32_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(match value { - Some(val) => Decimal32Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal32Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) - } - - fn build_decimal64_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(match value { - Some(val) => Decimal64Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal64Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) - } - - fn build_decimal128_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(match value { - Some(val) => Decimal128Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal128Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) - } - - fn build_decimal256_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(repeat_n(value, size) - .collect::() - .with_precision_and_scale(precision, scale)?) - } - /// Converts `Vec` where each element has type corresponding to /// `data_type`, to a single element [`ListArray`]. 
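A minimal sketch of the `from_value` pattern that supersedes the removed `build_decimal*_array` helpers (illustrative values; `from_value` repeats one native value `size` times):

    // Roughly equivalent to the old build_decimal128_array(Some(123), 10, 2, 4)
    let arr = Decimal128Array::from_value(123_i128, 4)
        .with_precision_and_scale(10, 2)?;
    assert_eq!(arr.len(), 4);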
     ///
@@ -2863,23 +3305,40 @@ impl ScalarValue {
     ///
     /// Errors if `self` is
     /// - a decimal that fails to be converted to a decimal array of size
-    /// - a `FixedsizeList` that fails to be concatenated into an array of size
+    /// - a `FixedSizeList` that fails to be concatenated into an array of size
     /// - a `List` that fails to be concatenated into an array of size
     /// - a `Dictionary` that fails to be converted to a dictionary array of size
     pub fn to_array_of_size(&self, size: usize) -> Result<ArrayRef> {
         Ok(match self {
-            ScalarValue::Decimal32(e, precision, scale) => Arc::new(
-                ScalarValue::build_decimal32_array(*e, *precision, *scale, size)?,
+            ScalarValue::Decimal32(Some(e), precision, scale) => Arc::new(
+                Decimal32Array::from_value(*e, size)
+                    .with_precision_and_scale(*precision, *scale)?,
             ),
-            ScalarValue::Decimal64(e, precision, scale) => Arc::new(
-                ScalarValue::build_decimal64_array(*e, *precision, *scale, size)?,
+            ScalarValue::Decimal32(None, precision, scale) => {
+                new_null_array(&DataType::Decimal32(*precision, *scale), size)
+            }
+            ScalarValue::Decimal64(Some(e), precision, scale) => Arc::new(
+                Decimal64Array::from_value(*e, size)
+                    .with_precision_and_scale(*precision, *scale)?,
             ),
-            ScalarValue::Decimal128(e, precision, scale) => Arc::new(
-                ScalarValue::build_decimal128_array(*e, *precision, *scale, size)?,
+            ScalarValue::Decimal64(None, precision, scale) => {
+                new_null_array(&DataType::Decimal64(*precision, *scale), size)
+            }
+            ScalarValue::Decimal128(Some(e), precision, scale) => Arc::new(
+                Decimal128Array::from_value(*e, size)
+                    .with_precision_and_scale(*precision, *scale)?,
             ),
-            ScalarValue::Decimal256(e, precision, scale) => Arc::new(
-                ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?,
+            ScalarValue::Decimal128(None, precision, scale) => {
+                new_null_array(&DataType::Decimal128(*precision, *scale), size)
+            }
+            ScalarValue::Decimal256(Some(e), precision, scale) => Arc::new(
+                Decimal256Array::from_value(*e, size)
+                    .with_precision_and_scale(*precision, *scale)?,
             ),
+            ScalarValue::Decimal256(None, precision, scale) => {
+                new_null_array(&DataType::Decimal256(*precision, *scale), size)
+            }
+
             ScalarValue::Boolean(e) => match e {
                 None => new_null_array(&DataType::Boolean, size),
                 Some(true) => {
@@ -2952,33 +3411,35 @@ impl ScalarValue {
                 )
             }
             ScalarValue::Utf8(e) => match e {
-                Some(value) => {
-                    Arc::new(StringArray::from_iter_values(repeat_n(value, size)))
-                }
+                Some(value) => Arc::new(StringArray::new_repeated(value, size)),
                 None => new_null_array(&DataType::Utf8, size),
             },
             ScalarValue::Utf8View(e) => match e {
                 Some(value) => {
-                    Arc::new(StringViewArray::from_iter_values(repeat_n(value, size)))
+                    let mut builder = StringViewBuilder::with_capacity(size);
+                    builder.try_append_value_n(value, size)?;
+                    let array = builder.finish();
+                    Arc::new(array)
                 }
                 None => new_null_array(&DataType::Utf8View, size),
             },
             ScalarValue::LargeUtf8(e) => match e {
-                Some(value) => {
-                    Arc::new(LargeStringArray::from_iter_values(repeat_n(value, size)))
-                }
+                Some(value) => Arc::new(LargeStringArray::new_repeated(value, size)),
                 None => new_null_array(&DataType::LargeUtf8, size),
             },
             ScalarValue::Binary(e) => match e {
-                Some(value) => Arc::new(
-                    repeat_n(Some(value.as_slice()), size).collect::<BinaryArray>(),
-                ),
+                Some(value) => {
+                    Arc::new(BinaryArray::new_repeated(value.as_slice(), size))
+                }
                 None => new_null_array(&DataType::Binary, size),
             },
             ScalarValue::BinaryView(e) => match e {
-                Some(value) => Arc::new(
-                    repeat_n(Some(value.as_slice()), size).collect::<BinaryViewArray>(),
-                ),
+                Some(value) => {
+                    let mut
builder = BinaryViewBuilder::with_capacity(size); + builder.try_append_value_n(value, size)?; + let array = builder.finish(); + Arc::new(array) + } None => new_null_array(&DataType::BinaryView, size), }, ScalarValue::FixedSizeBinary(s, e) => match e { @@ -2992,9 +3453,9 @@ impl ScalarValue { None => Arc::new(FixedSizeBinaryArray::new_null(*s, size)), }, ScalarValue::LargeBinary(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), + Some(value) => { + Arc::new(LargeBinaryArray::new_repeated(value.as_slice(), size)) + } None => new_null_array(&DataType::LargeBinary, size), }, ScalarValue::List(arr) => { @@ -3015,6 +3476,18 @@ impl ScalarValue { } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } + ScalarValue::ListView(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } + ScalarValue::LargeListView(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } ScalarValue::Struct(arr) => { if size == 1 { return Ok(Arc::clone(arr) as Arc); @@ -3153,10 +3626,7 @@ impl ScalarValue { .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; Arc::new(ar) } - None => { - let dt = self.data_type(); - new_null_array(&dt, size) - } + None => new_null_array(&DataType::Union(fields.clone(), *mode), size), }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) @@ -3172,6 +3642,54 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + ScalarValue::RunEndEncoded(run_ends_field, values_field, value) => { + fn make_run_array( + run_ends_field: &Arc, + values_field: &Arc, + value: &ScalarValue, + size: usize, + ) -> Result { + let size_native = R::Native::from_usize(size) + .ok_or_else(|| DataFusionError::Execution(format!("Cannot construct RunArray of size {size}: Overflows run-ends type {}", R::DATA_TYPE)))?; + let values = value.to_array_of_size(1)?; + let run_ends = + PrimitiveArray::::new(vec![size_native].into(), None); + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(size) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + match run_ends_field.data_type() { + DataType::Int16 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int32 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int64 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => get_or_create_cached_null_array(size), }) } @@ -3225,13 +3743,22 @@ impl ScalarValue { } } + /// Repeats the rows of `arr` `size` times, producing an array with + /// `arr.len() * size` total rows. 
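For example (a sketch mirroring the tests further below), repeating a one-row list three times via the private helper documented above:

    // [[10, 20]] repeated 3 times -> [[10, 20], [10, 20], [10, 20]]
    let arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
        Some(10),
        Some(20),
    ])]);
    let repeated = ScalarValue::list_to_array_of_size(&arr, 3)?;
    assert_eq!(repeated.len(), 3);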
fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = repeat_n(arr, size).collect::>(); - let ret = match !arrays.is_empty() { - true => arrow::compute::concat(arrays.as_slice())?, - false => arr.slice(0, 0), - }; - Ok(ret) + if size == 0 { + return Ok(arr.slice(0, 0)); + } + + // Examples: given `arr = [[A, B, C]]` and `size = 3`, `indices = [0, 0, 0]` and + // the result is `[[A, B, C], [A, B, C], [A, B, C]]`. + // + // Given `arr = [[A, B], [C]]` and `size = 2`, `indices = [0, 1, 0, 1]` and the + // result is `[[A, B], [C], [A, B], [C]]`. (But in practice, we are always called + // with `arr.len() == 1`.) + let n = arr.len() as u32; + let indices = UInt32Array::from_iter_values((0..size).flat_map(|_| 0..n)); + Ok(arrow::compute::take(arr, &indices, None)?) } /// Retrieve ScalarValue for each row in `array` @@ -3338,29 +3865,35 @@ impl ScalarValue { pub fn convert_array_to_scalar_vec( array: &dyn Array, ) -> Result>>> { - fn generic_collect( - array: &dyn Array, - ) -> Result>>> { - array - .as_list::() - .iter() - .map(|nested_array| { - nested_array - .map(|array| { - (0..array.len()) - .map(|i| ScalarValue::try_from_array(&array, i)) - .collect::>>() - }) - .transpose() + fn map_element( + nested_array: Option, + ) -> Result>> { + nested_array + .map(|array| { + (0..array.len()) + .map(|i| ScalarValue::try_from_array(&array, i)) + .collect::>>() }) - .collect() + .transpose() } match array.data_type() { - DataType::List(_) => generic_collect::(array), - DataType::LargeList(_) => generic_collect::(array), + DataType::List(_) => array.as_list::().iter().map(map_element).collect(), + DataType::LargeList(_) => { + array.as_list::().iter().map(map_element).collect() + } + DataType::ListView(_) => array + .as_list_view::() + .iter() + .map(map_element) + .collect(), + DataType::LargeListView(_) => array + .as_list_view::() + .iter() + .map(map_element) + .collect(), _ => _internal_err!( - "ScalarValue::convert_array_to_scalar_vec input must be a List/LargeList type" + "ScalarValue::convert_array_to_scalar_vec input must be a List/LargeList/ListView/LargeListView type" ), } } @@ -3379,7 +3912,7 @@ impl ScalarValue { /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value - if !array.is_valid(index) { + if array.is_null(index) { return array.data_type().try_into(); } @@ -3457,6 +3990,22 @@ impl ScalarValue { .with_field(field) .build_fixed_size_list_scalar(list_size) } + DataType::ListView(field) => { + let list_array = as_list_view_array(array)?; + let nested_array = list_array.value(index); + // Produces a single element `ListViewArray` with the value at `index`. + SingleRowListArrayBuilder::new(nested_array) + .with_field(field) + .build_list_view_scalar() + } + DataType::LargeListView(field) => { + let list_array = as_large_list_view_array(array)?; + let nested_array = list_array.value(index); + // Produces a single element `LargeListViewArray` with the value at `index`. 
+                SingleRowListArrayBuilder::new(nested_array)
+                    .with_field(field)
+                    .build_large_list_view_scalar()
+            }
             DataType::Date32 => typed_cast!(array, index, as_date32_array, Date32)?,
             DataType::Date64 => typed_cast!(array, index, as_date64_array, Date64)?,
             DataType::Time32(TimeUnit::Second) => {
@@ -3522,6 +4071,28 @@
                 Self::Dictionary(key_type.clone(), Box::new(value))
             }
+            DataType::RunEndEncoded(run_ends_field, value_field) => {
+                // Explicitly check length here since get_physical_index() doesn't
+                // bound check for us
+                if index >= array.len() {
+                    return _exec_err!(
+                        "Index {index} out of bounds for array of length {}",
+                        array.len()
+                    );
+                }
+                let scalar = downcast_run_array!(
+                    array => {
+                        let index = array.get_physical_index(index);
+                        ScalarValue::try_from_array(array.values(), index)?
+                    },
+                    dt => unreachable!("Invalid run-ends type: {dt}")
+                );
+                Self::RunEndEncoded(
+                    Arc::clone(run_ends_field),
+                    Arc::clone(value_field),
+                    Box::new(scalar),
+                )
+            }
             DataType::Struct(_) => {
                 let a = array.slice(index, 1);
                 Self::Struct(Arc::new(a.as_struct().to_owned()))
@@ -3634,6 +4205,7 @@ impl ScalarValue {
             ScalarValue::LargeUtf8(v) => v,
             ScalarValue::Utf8View(v) => v,
             ScalarValue::Dictionary(_, v) => return v.try_as_str(),
+            ScalarValue::RunEndEncoded(_, _, v) => return v.try_as_str(),
             _ => return None,
         };
         Some(v.as_ref().map(|v| v.as_str()))
@@ -3650,11 +4222,38 @@ impl ScalarValue {
         target_type: &DataType,
         cast_options: &CastOptions<'static>,
     ) -> Result<Self> {
+        let source_type = self.data_type();
+        if let Some(multiplier) = date_to_timestamp_multiplier(&source_type, target_type)
+            && let Some(value) = self.date_scalar_value_as_i64()
+        {
+            ensure_timestamp_in_bounds(value, multiplier, &source_type, target_type)?;
+        }
+
         let scalar_array = self.to_array()?;
-        let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?;
+
+        // For types that contain structs (including nested inside Lists, Dictionaries,
+        // etc.), use name-based casting logic that matches struct fields by name and
+        // recursively casts nested structs.
+        let cast_arr = if crate::nested_struct::requires_nested_struct_cast(
+            scalar_array.data_type(),
+            target_type,
+        ) {
+            crate::nested_struct::cast_column(&scalar_array, target_type, cast_options)?
+        } else {
+            cast_with_options(&scalar_array, target_type, cast_options)?
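As an illustration of the name-based path (hypothetical field shapes, not part of this patch):

    // Hypothetical shapes: fields are matched by name, not position.
    let source = DataType::Struct(Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]));
    let target = DataType::Struct(Fields::from(vec![
        Field::new("b", DataType::Utf8, true),
        Field::new("a", DataType::Int64, true),
    ]));
    // A ScalarValue::Struct of `source` cast via `cast_to(&target, &options)`
    // widens `a` to Int64 and keeps each value paired with its named field,
    // where a purely positional cast would mispair the fields.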
+        };
+
         ScalarValue::try_from_array(&cast_arr, 0)
     }
+    fn date_scalar_value_as_i64(&self) -> Option<i64> {
+        match self {
+            ScalarValue::Date32(Some(value)) => Some(i64::from(*value)),
+            ScalarValue::Date64(Some(value)) => Some(*value),
+            _ => None,
+        }
+    }
+
     fn eq_array_decimal32(
         array: &ArrayRef,
         index: usize,
@@ -3862,6 +4461,12 @@ impl ScalarValue {
             ScalarValue::FixedSizeList(arr) => {
                 Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index)
             }
+            ScalarValue::ListView(arr) => {
+                Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index)
+            }
+            ScalarValue::LargeListView(arr) => {
+                Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index)
+            }
             ScalarValue::Struct(arr) => {
                 Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index)
             }
@@ -3947,6 +4552,34 @@ impl ScalarValue {
                     None => v.is_null(),
                 }
             }
+            ScalarValue::RunEndEncoded(run_ends_field, _, value) => {
+                // Explicitly check length here since get_physical_index() doesn't
+                // bound check for us
+                if index >= array.len() {
+                    return _exec_err!(
+                        "Index {index} out of bounds for array of length {}",
+                        array.len()
+                    );
+                }
+                match run_ends_field.data_type() {
+                    DataType::Int16 => {
+                        let array = as_run_array::<Int16Type>(array)?;
+                        let index = array.get_physical_index(index);
+                        value.eq_array(array.values(), index)?
+                    }
+                    DataType::Int32 => {
+                        let array = as_run_array::<Int32Type>(array)?;
+                        let index = array.get_physical_index(index);
+                        value.eq_array(array.values(), index)?
+                    }
+                    DataType::Int64 => {
+                        let array = as_run_array::<Int64Type>(array)?;
+                        let index = array.get_physical_index(index);
+                        value.eq_array(array.values(), index)?
+                    }
+                    dt => unreachable!("Invalid run-ends type: {dt}"),
+                }
+            }
             ScalarValue::Null => array.is_null(index),
         })
     }
@@ -4021,6 +4654,8 @@ impl ScalarValue {
             ScalarValue::List(arr) => arr.get_array_memory_size(),
             ScalarValue::LargeList(arr) => arr.get_array_memory_size(),
             ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
+            ScalarValue::ListView(arr) => arr.get_array_memory_size(),
+            ScalarValue::LargeListView(arr) => arr.get_array_memory_size(),
             ScalarValue::Struct(arr) => arr.get_array_memory_size(),
             ScalarValue::Map(arr) => arr.get_array_memory_size(),
             ScalarValue::Union(vals, fields, _mode) => {
@@ -4036,6 +4671,7 @@
                 // `dt` and `sv` are boxed, so they are NOT already included in `self`
                 dt.size() + sv.size()
             }
+            ScalarValue::RunEndEncoded(rf, vf, v) => rf.size() + vf.size() + v.size(),
         }
     }
@@ -4066,6 +4702,7 @@
     /// Estimates [size](Self::size) of [`HashSet`] in bytes.
     ///
     /// Includes the size of the [`HashSet`] container itself.
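A rough usage sketch of the estimate described above (default hasher assumed; the figure is capacity-based, not an exact byte count):

    let mut set: HashSet<ScalarValue> = HashSet::with_capacity(8);
    set.insert(ScalarValue::Int32(Some(1)));
    let estimate = ScalarValue::size_of_hashset(&set);
    // At least the container itself, plus capacity-proportional slots.
    assert!(estimate >= std::mem::size_of::<HashSet<ScalarValue>>());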
+ #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key pub fn size_of_hashset(set: &HashSet) -> usize { size_of_val(set) + (size_of::() * set.capacity()) @@ -4135,6 +4772,14 @@ impl ScalarValue { let array = copy_array_data(&arr.to_data()); *Arc::make_mut(arr) = LargeListArray::from(array) } + ScalarValue::ListView(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = ListViewArray::from(array); + } + ScalarValue::LargeListView(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = LargeListViewArray::from(array) + } ScalarValue::Struct(arr) => { let array = copy_array_data(&arr.to_data()); *Arc::make_mut(arr) = StructArray::from(array); @@ -4151,6 +4796,9 @@ impl ScalarValue { ScalarValue::Dictionary(_, value) => { value.compact(); } + ScalarValue::RunEndEncoded(_, _, value) => { + value.compact(); + } } } @@ -4354,6 +5002,20 @@ impl ScalarValue { _ => None, } } + + /// A thin wrapper on Arrow's validation that throws internal error if validation + /// fails. + fn validate_decimal_or_internal_err( + precision: u8, + scale: i8, + ) -> Result<()> { + validate_decimal_precision_and_scale::(precision, scale).map_err(|err| { + _internal_datafusion_err!( + "Decimal precision/scale invariant violated \ + (precision={precision}, scale={scale}): {err}" + ) + }) + } } /// Compacts the data of an `ArrayData` into a new `ArrayData`. @@ -4407,6 +5069,7 @@ macro_rules! impl_scalar { impl_scalar!(f64, Float64); impl_scalar!(f32, Float32); +impl_scalar!(f16, Float16); impl_scalar!(i8, Int8); impl_scalar!(i16, Int16); impl_scalar!(i32, Int32); @@ -4563,6 +5226,7 @@ impl_try_from!(UInt8, u8); impl_try_from!(UInt16, u16); impl_try_from!(UInt32, u32); impl_try_from!(UInt64, u64); +impl_try_from!(Float16, f16); impl_try_from!(Float32, f32); impl_try_from!(Float64, f64); impl_try_from!(Boolean, bool); @@ -4639,8 +5303,10 @@ impl fmt::Display for ScalarValue { | ScalarValue::BinaryView(e) => match e { Some(bytes) => { // print up to first 10 bytes, with trailing ... 
if needed + const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; for b in bytes.iter().take(10) { - write!(f, "{b:02X}")?; + f.write_char(HEX_CHARS_UPPER[(b >> 4) as usize] as char)?; + f.write_char(HEX_CHARS_UPPER[(b & 0x0f) as usize] as char)?; } if bytes.len() > 10 { write!(f, "...")?; @@ -4648,9 +5314,11 @@ impl fmt::Display for ScalarValue { } None => write!(f, "NULL")?, }, - ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::List(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::LargeList(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::FixedSizeList(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::ListView(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::LargeListView(arr) => fmt_list(arr.as_ref(), f)?, ScalarValue::Date32(e) => format_option!( f, e.map(|v| { @@ -4766,18 +5434,18 @@ impl fmt::Display for ScalarValue { None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, + ScalarValue::RunEndEncoded(_, _, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) } } -fn fmt_list(arr: ArrayRef, f: &mut fmt::Formatter) -> fmt::Result { - // ScalarValue List, LargeList, FixedSizeList should always have a single element +fn fmt_list(arr: &dyn Array, f: &mut fmt::Formatter) -> fmt::Result { + // ScalarValue List, LargeList, FixedSizeList, ListView, LargeListView should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); - let formatter = - ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); let value_formatter = formatter.value(0); write!(f, "{value_formatter}") } @@ -4860,6 +5528,8 @@ impl fmt::Debug for ScalarValue { ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), + ScalarValue::ListView(_) => write!(f, "ListView({self})"), + ScalarValue::LargeListView(_) => write!(f, "LargeListView({self})"), ScalarValue::Struct(struct_arr) => { // ScalarValue Struct should always have a single element assert_eq!(struct_arr.len(), 1); @@ -4945,6 +5615,9 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), + ScalarValue::RunEndEncoded(rf, vf, v) => { + write!(f, "RunEndEncoded({rf:?}, {vf:?}, {v:?})") + } ScalarValue::Null => write!(f, "NULL"), } } @@ -4994,24 +5667,26 @@ impl ScalarType for Date32Type { #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; - use crate::cast::{as_list_array, as_map_array, as_struct_array}; + use crate::cast::{ + as_large_list_view_array, as_list_array, as_map_array, as_struct_array, + }; use crate::test_util::batches_to_string; use arrow::array::{ - FixedSizeListBuilder, Int32Builder, LargeListBuilder, ListBuilder, MapBuilder, - NullArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveBuilder, RecordBatch, - StringBuilder, StringDictionaryBuilder, StructBuilder, UnionBuilder, + FixedSizeListBuilder, Int32Builder, LargeListBuilder, LargeListViewBuilder, + ListBuilder, ListViewBuilder, MapBuilder, NullArray, NullBufferBuilder, + OffsetSizeTrait, PrimitiveBuilder, RecordBatch, StringBuilder, + StringDictionaryBuilder, StructBuilder, UnionBuilder, }; use 
arrow::buffer::{Buffer, NullBuffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; use arrow::datatypes::{ - ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION, + ArrowNumericType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, Fields, + Float64Type, TimeUnit, }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; - use chrono::NaiveDate; use insta::assert_snapshot; use rand::Rng; @@ -5040,6 +5715,52 @@ mod tests { assert_eq!(actual, &expected); } + #[test] + fn test_format_timestamp_type_for_error_and_bounds() { + // format helper + let ts_ns = format_timestamp_type_for_error(&DataType::Timestamp( + TimeUnit::Nanosecond, + None, + )); + assert_eq!(ts_ns, "Timestamp(ns)"); + + let ts_us = format_timestamp_type_for_error(&DataType::Timestamp( + TimeUnit::Microsecond, + None, + )); + assert_eq!(ts_us, "Timestamp(us)"); + + // ensure_timestamp_in_bounds: Date32 non-overflow + let ok = ensure_timestamp_in_bounds( + 1000, + NANOS_PER_DAY, + &DataType::Date32, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(ok.is_ok()); + + // Date32 overflow -- known large day value (9999-12-31 -> 2932896) + let err = ensure_timestamp_in_bounds( + 2932896, + NANOS_PER_DAY, + &DataType::Date32, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(err.is_err()); + let msg = err.unwrap_err().to_string(); + assert!(msg.contains("Cannot cast Date32 value 2932896 to Timestamp(ns): converted value exceeds the representable i64 range")); + + // Date64 overflow for ns (millis * 1_000_000) + let overflow_millis: i64 = (i64::MAX / NANOS_PER_MILLISECOND) + 1; + let err2 = ensure_timestamp_in_bounds( + overflow_millis, + NANOS_PER_MILLISECOND, + &DataType::Date64, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(err2.is_err()); + } + #[test] fn test_scalar_value_from_for_struct() { let boolean = Arc::new(BooleanArray::from(vec![false])); @@ -5142,6 +5863,27 @@ mod tests { ]); assert_eq!(&arr, actual_list_arr); + + // ListView + let arr = + ListViewArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + ])]); + + let sv = ScalarValue::ListView(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + let actual_list_arr = actual_arr.as_list_view::(); + + let arr = ListViewArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(1), None, Some(2)]), + ]); + + assert_eq!(&arr, actual_list_arr); } #[test] @@ -5171,6 +5913,91 @@ mod tests { assert_eq!(empty_array.len(), 0); } + #[test] + fn test_to_array_of_size_list_size_one() { + // size=1 takes the fast path (Arc::clone) + let arr = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(20), + ])]); + let sv = ScalarValue::List(Arc::new(arr.clone())); + let result = sv.to_array_of_size(1).unwrap(); + assert_eq!(result.as_list::(), &arr); + } + + #[test] + fn test_to_array_of_size_list_empty_inner() { + // A list scalar containing an empty list: [[]] + let arr = ListArray::from_iter_primitive::(vec![Some(vec![])]); + let sv = ScalarValue::List(Arc::new(arr)); + let result = sv.to_array_of_size(3).unwrap(); + let result_list = result.as_list::(); + assert_eq!(result_list.len(), 3); + for i in 0..3 { + assert_eq!(result_list.value(i).len(), 0); + } + } + + #[test] + fn test_to_array_of_size_large_list() { + let arr = + LargeListArray::from_iter_primitive::(vec![Some(vec![ + Some(100), + Some(200), + ])]); + let sv = ScalarValue::LargeList(Arc::new(arr)); + 
let result = sv.to_array_of_size(3).unwrap(); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(100), Some(200)]), + Some(vec![Some(100), Some(200)]), + Some(vec![Some(100), Some(200)]), + ]); + assert_eq!(result.as_list::(), &expected); + } + + #[test] + fn test_list_to_array_of_size_multi_row() { + // Call list_to_array_of_size directly with arr.len() > 1 + let arr = Int32Array::from(vec![Some(10), None, Some(30)]); + let result = ScalarValue::list_to_array_of_size(&arr, 3).unwrap(); + let result = result.as_primitive::(); + assert_eq!( + result.iter().collect::>(), + vec![ + Some(10), + None, + Some(30), + Some(10), + None, + Some(30), + Some(10), + None, + Some(30), + ] + ); + } + + #[test] + fn test_to_array_of_size_null_list() { + let dt = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); + let sv = ScalarValue::try_from(&dt).unwrap(); + let result = sv.to_array_of_size(3).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result.null_count(), 3); + } + + /// See https://github.com/apache/datafusion/issues/18870 + #[test] + fn test_to_array_of_size_for_none_fsb() { + let sv = ScalarValue::FixedSizeBinary(5, None); + let result = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + assert_eq!(result.as_fixed_size_binary().values().len(), 10); + } + #[test] fn test_list_to_array_string() { let scalars = vec![ @@ -5196,29 +6023,11 @@ mod tests { values .into_iter() .map(|v| { - let arr = if v.is_some() { - Arc::new( - GenericListArray::::from_iter_primitive::( - vec![v], - ), - ) - } else if O::IS_LARGE { - new_null_array( - &DataType::LargeList(Arc::new(Field::new_list_field( - DataType::Int64, - true, - ))), - 1, - ) - } else { - new_null_array( - &DataType::List(Arc::new(Field::new_list_field( - DataType::Int64, - true, - ))), - 1, - ) - }; + let arr = Arc::new(GenericListArray::::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![v])) as ArrayRef; if O::IS_LARGE { ScalarValue::LargeList(arr.as_list::().to_owned().into()) @@ -5229,6 +6038,29 @@ mod tests { .collect() } + fn build_list_view( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = Arc::new(GenericListViewArray::::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![v])) as ArrayRef; + + if O::IS_LARGE { + ScalarValue::LargeListView( + arr.as_list_view::().to_owned().into(), + ) + } else { + ScalarValue::ListView(arr.as_list_view::().to_owned().into()) + } + }) + .collect() + } + #[test] fn test_iter_to_array_fixed_size_list() { let field = Arc::new(Field::new_list_field(DataType::Int32, true)); @@ -5357,13 +6189,13 @@ mod tests { #[test] fn iter_to_array_primitive_test() { + // List // List[[1,2,3]], List[null], List[[4,5]] let scalars = build_list::(vec![ Some(vec![Some(1), Some(2), Some(3)]), None, Some(vec![Some(4), Some(5)]), ]); - let array = ScalarValue::iter_to_array(scalars).unwrap(); let list_array = as_list_array(&array).unwrap(); // List[[1,2,3], null, [4,5]] @@ -5374,20 +6206,57 @@ mod tests { ]); assert_eq!(list_array, &expected); + // LargeList + // List[[1,2,3]], List[null], List[[4,5]] let scalars = build_list::(vec![ Some(vec![Some(1), Some(2), Some(3)]), None, Some(vec![Some(4), Some(5)]), ]); - let array = ScalarValue::iter_to_array(scalars).unwrap(); - let list_array = as_large_list_array(&array).unwrap(); + let large_list_array = as_large_list_array(&array).unwrap(); let expected = LargeListArray::from_iter_primitive::(vec![ 
Some(vec![Some(1), Some(2), Some(3)]), None, Some(vec![Some(4), Some(5)]), ]); - assert_eq!(list_array, &expected); + assert_eq!(large_list_array, &expected); + + // ListView + // ListView[[1,2,3]], ListView[null], ListView[[4,5]] + let scalars = build_list_view::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_view_array = as_list_view_array(&array).unwrap(); + // ListView[[1,2,3], null, [4,5]] + let expected = ListViewArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_view_array, &expected); + + // LargeListView + // LargeListView[[1,2,3]], LargeListView[null], LargeListView[[4,5]] + let scalars = build_list_view::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let large_list_view_array = as_large_list_view_array(&array).unwrap(); + // LargeListView[[1,2,3], null, [4,5]] + let expected = LargeListViewArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(large_list_view_array, &expected); } #[test] @@ -5430,16 +6299,36 @@ mod tests { ])); let fsl_array: ArrayRef = - Arc::new(ListArray::from_iter_primitive::(vec![ + Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ], + 3, + )); + let list_view_array: ArrayRef = + Arc::new(ListViewArray::from_iter_primitive::(vec![ Some(vec![Some(0), Some(1), Some(2)]), None, - Some(vec![Some(3), None, Some(5)]), + Some(vec![None, Some(5)]), ])); - for arr in [list_array, fsl_array] { + for arr in [list_array, fsl_array, list_view_array] { for i in 0..arr.len() { - let scalar = - ScalarValue::List(arr.slice(i, 1).as_list::().to_owned().into()); + let slice = arr.slice(i, 1); + let scalar = match arr.data_type() { + DataType::List(_) => { + ScalarValue::List(slice.as_list::().to_owned().into()) + } + DataType::FixedSizeList(_, _) => ScalarValue::FixedSizeList( + slice.as_fixed_size_list().to_owned().into(), + ), + DataType::ListView(_) => ScalarValue::ListView( + slice.as_list_view::().to_owned().into(), + ), + _ => unreachable!(), + }; assert!(scalar.eq_array(&arr, i).unwrap()); } } @@ -5471,6 +6360,68 @@ mod tests { Ok(()) } + #[test] + fn scalar_add_trait_null_test() -> Result<()> { + let int_value = ScalarValue::Int32(Some(42)); + + assert_eq!( + int_value.add(ScalarValue::Int32(None))?, + ScalarValue::Int32(None) + ); + + Ok(()) + } + + #[test] + fn scalar_add_trait_wrapping_overflow_test() -> Result<()> { + let int_value = ScalarValue::Int32(Some(i32::MAX)); + let one = ScalarValue::Int32(Some(1)); + + assert_eq!(int_value.add(one)?, ScalarValue::Int32(Some(i32::MIN))); + + Ok(()) + } + + #[test] + fn scalar_add_trait_decimal_scale_test() -> Result<()> { + let decimal = ScalarValue::Decimal128(Some(123), 10, 2); + let decimal_2 = ScalarValue::Decimal128(Some(4), 9, 1); + + assert_eq!( + decimal.add(decimal_2)?, + ScalarValue::Decimal128(Some(163), 11, 2) + ); + + Ok(()) + } + + #[test] + fn scalar_add_trait_decimal256_scale_test() -> Result<()> { + let decimal = ScalarValue::Decimal256(Some(i256::from(123)), 10, 2); + let decimal_2 = ScalarValue::Decimal256(Some(i256::from(4)), 9, 1); + + assert_eq!( + decimal.add(decimal_2)?, + 
ScalarValue::Decimal256(Some(i256::from(163)), 11, 2) + ); + + Ok(()) + } + + #[test] + fn scalar_add_trait_decimal_negative_scale_test() -> Result<()> { + let decimal = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, i8::MIN); + let decimal_2 = + ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, i8::MIN); + + assert_eq!( + decimal.add(decimal_2)?, + ScalarValue::Decimal128(Some(2), DECIMAL128_MAX_PRECISION, i8::MIN) + ); + + Ok(()) + } + #[test] fn scalar_sub_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); @@ -5526,7 +6477,10 @@ mod tests { .sub_checked(&int_value_2) .unwrap_err() .strip_backtrace(); - assert_eq!(err, "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808") + assert_eq!( + err, + "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808" + ) } #[test] @@ -5567,6 +6521,43 @@ mod tests { Ok(()) } + #[test] + fn scalar_decimal_add_overflow_test() { + check_scalar_decimal_add_overflow::( + ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, 0), + ); + check_scalar_decimal_add_overflow::( + ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ScalarValue::Decimal256(Some(i256::ONE), DECIMAL256_MAX_PRECISION, 0), + ); + } + + #[test] + fn scalar_decimal_in_place_add_error_preserves_lhs() { + let mut lhs = + ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0); + let original = lhs.clone(); + + let err = lhs + .try_add_checked_in_place(&ScalarValue::Decimal128( + Some(1), + DECIMAL128_MAX_PRECISION, + 0, + )) + .unwrap_err() + .strip_backtrace(); + + assert_eq!( + err, + format!( + "Arrow error: Arithmetic overflow: Overflow happened on: {} + 1", + i128::MAX + ) + ); + assert_eq!(lhs, original); + } + // Verifies that ScalarValue has the same behavior with compute kernel when it overflows. fn check_scalar_add_overflow(left: ScalarValue, right: ScalarValue) where @@ -5583,6 +6574,22 @@ mod tests { assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); } + // Verifies the decimal fast path preserves the same overflow behavior as Arrow kernels. 
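For contrast with the checked tests above, a sketch of the wrapping vs. checked contract (values hypothetical, mirroring the overflow tests in this file):

    let max = ScalarValue::Int32(Some(i32::MAX));
    let one = ScalarValue::Int32(Some(1));
    // `add` follows Arrow's wrapping kernel...
    assert_eq!(max.add(&one).unwrap(), ScalarValue::Int32(Some(i32::MIN)));
    // ...while `add_checked` surfaces the overflow as an error.
    assert!(max.add_checked(&one).is_err());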
+ fn check_scalar_decimal_add_overflow(left: ScalarValue, right: ScalarValue) + where + T: ArrowPrimitiveType, + { + let scalar_result = left.add(&right); + + let left_array = left.to_array().expect("Failed to convert to array"); + let right_array = right.to_array().expect("Failed to convert to array"); + let arrow_left_array = left_array.as_primitive::(); + let arrow_right_array = right_array.as_primitive::(); + let arrow_result = add_wrapping(arrow_left_array, arrow_right_array); + + assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); + } + #[test] fn test_interval_add_timestamp() -> Result<()> { let interval = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { @@ -5674,12 +6681,16 @@ mod tests { assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value - .eq_array(&array, 1) - .expect("Failed to compare arrays")); - assert!(decimal_value - .eq_array(&array, 5) - .expect("Failed to compare arrays")); + assert!( + decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays") + ); + assert!( + decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays") + ); // test try from array assert_eq!( decimal_value, @@ -5724,18 +6735,24 @@ mod tests { assert_eq!(4, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); - assert!(ScalarValue::try_new_decimal128(1, 10, 2) - .unwrap() - .eq_array(&array, 0) - .expect("Failed to compare arrays")); - assert!(ScalarValue::try_new_decimal128(2, 10, 2) - .unwrap() - .eq_array(&array, 1) - .expect("Failed to compare arrays")); - assert!(ScalarValue::try_new_decimal128(3, 10, 2) - .unwrap() - .eq_array(&array, 2) - .expect("Failed to compare arrays")); + assert!( + ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0) + .expect("Failed to compare arrays") + ); + assert!( + ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1) + .expect("Failed to compare arrays") + ); + assert!( + ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 2) + .expect("Failed to compare arrays") + ); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -6010,6 +7027,40 @@ mod tests { ), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::ListView(Arc::new(ListViewArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]))); + let b = ScalarValue::ListView(Arc::new(ListViewArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]))); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::LargeListView(Arc::new( + LargeListViewArray::from_iter_primitive::(vec![Some( + vec![None, Some(2), Some(3)], + )]), + )); + let b = + ScalarValue::LargeListView(Arc::new( + LargeListViewArray::from_iter_primitive::(vec![Some( + vec![Some(1), Some(2), Some(3)], + )]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); } #[test] @@ -6171,8 +7222,6 @@ mod tests { } #[test] - // despite clippy claiming they are useless, the code doesn't compile otherwise. 
- #[allow(clippy::useless_vec)] fn scalar_iter_to_array_boolean() { check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]); check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]); @@ -6222,12 +7271,12 @@ mod tests { check_scalar_iter_binary!( Binary, BinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] + [Some(b"foo"), None, Some(b"bar")] ); check_scalar_iter_binary!( LargeBinary, LargeBinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] + [Some(b"foo"), None, Some(b"bar")] ); } @@ -6359,6 +7408,30 @@ mod tests { ); assert_eq!(expected, scalar); assert!(expected.is_null()); + + // Test for ListView + let data_type = &DataType::ListView(Arc::clone(&inner_field)); + let scalar: ScalarValue = data_type.try_into().unwrap(); + let expected = ScalarValue::ListView( + new_null_array(data_type, 1) + .as_list_view::() + .to_owned() + .into(), + ); + assert_eq!(expected, scalar); + assert!(expected.is_null()); + + // Test for LargeListView + let data_type = &DataType::LargeListView(Arc::clone(&inner_field)); + let scalar: ScalarValue = data_type.try_into().unwrap(); + let expected = ScalarValue::LargeListView( + new_null_array(data_type, 1) + .as_list_view::() + .to_owned() + .into(), + ); + assert_eq!(expected, scalar); + assert!(expected.is_null()); } #[test] @@ -6448,6 +7521,8 @@ mod tests { size_of::>() + (9 * size_of::()) + sv_size, ); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // ScalarValue has interior mutability but is intentionally used as hash key let mut s = HashSet::with_capacity(0); // do NOT clone `sv` here because this may shrink the vector capacity s.insert(v.pop().unwrap()); @@ -6680,7 +7755,9 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), + !scalar + .eq_array(&array, other_index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -7069,6 +8146,34 @@ mod tests { builder.append(true); Arc::new(builder.finish()) }, + // list view array + { + let values_builder = StringBuilder::new(); + let mut builder = ListViewBuilder::new(values_builder); + // [A, B] + builder.values().append_value("A"); + builder.values().append_value("B"); + builder.append(true); + // [ ] (empty list) + builder.append(true); + // Null + builder.append(false); + Arc::new(builder.finish()) + }, + // large list view array + { + let values_builder = StringBuilder::new(); + let mut builder = LargeListViewBuilder::new(values_builder); + // [A, B] + builder.values().append_value("A"); + builder.values().append_value("B"); + builder.append(true); + // [ ] (empty list) + builder.append(true); + // Null + builder.append(false); + Arc::new(builder.finish()) + }, // map { let string_builder = StringBuilder::new(); @@ -7108,6 +8213,31 @@ mod tests { } } + #[test] + fn roundtrip_run_array() { + // Comparison logic in round_trip_through_scalar doesn't work for RunArrays + // so we have a custom test for them + // TODO: https://github.com/apache/arrow-rs/pull/9213 might fix this ^ + let run_ends = Int16Array::from(vec![2, 3]); + let values = Int64Array::from(vec![Some(1), None]); + let run_array = RunArray::try_new(&run_ends, &values).unwrap(); + let run_array = run_array.downcast::().unwrap(); + + let expected_values = run_array.into_iter().collect::>(); + + for i in 0..run_array.len() { + let scalar = ScalarValue::try_from_array(&run_array, i).unwrap(); + let array = 
scalar.to_array_of_size(1).unwrap(); + assert_eq!(array.data_type(), run_array.data_type()); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!( + array.into_iter().collect::>(), + expected_values[i..i + 1] + ); + } + } + #[test] fn test_scalar_union_sparse() { let field_a = Arc::new(Field::new("A", DataType::Int32, true)); @@ -7566,6 +8696,38 @@ mod tests { }, DataType::LargeList(Arc::new(Field::new("element", DataType::Int64, true))), ); + check_scalar_cast( + { + let element_field = + Arc::new(Field::new("element", DataType::Int32, true)); + + let mut builder = + ListViewBuilder::new(Int32Builder::new()).with_field(element_field); + builder.append_value([Some(1)]); + builder.append(true); + + ScalarValue::ListView(Arc::new(builder.finish())) + }, + DataType::ListView(Arc::new(Field::new("element", DataType::Int64, true))), + ); + check_scalar_cast( + { + let element_field = + Arc::new(Field::new("element", DataType::Int32, true)); + + let mut builder = LargeListViewBuilder::new(Int32Builder::new()) + .with_field(element_field); + builder.append_value([Some(1)]); + builder.append(true); + + ScalarValue::LargeListView(Arc::new(builder.finish())) + }, + DataType::LargeListView(Arc::new(Field::new( + "element", + DataType::Int64, + true, + ))), + ); } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` @@ -7605,7 +8767,6 @@ mod tests { } #[test] - #[allow(arithmetic_overflow)] // we want to test them fn test_scalar_negative_overflows() -> Result<()> { macro_rules! test_overflow_on_value { ($($val:expr),* $(,)?) => {$( @@ -7922,6 +9083,42 @@ mod tests { ScalarValue::Decimal256(Some(10.into()), 1, 0), 5, ), + // Temporal types + ( + ScalarValue::Date32(Some(0)), + ScalarValue::Date32(Some(10)), + 10, + ), + ( + ScalarValue::Date32(Some(10)), + ScalarValue::Date32(Some(0)), + 10, + ), + ( + ScalarValue::Date64(Some(1000)), + ScalarValue::Date64(Some(5000)), + 4000, + ), + ( + ScalarValue::TimestampSecond(Some(100), None), + ScalarValue::TimestampSecond(Some(200), None), + 100, + ), + ( + ScalarValue::TimestampMillisecond(Some(1000), None), + ScalarValue::TimestampMillisecond(Some(5000), None), + 4000, + ), + ( + ScalarValue::TimestampMicrosecond(Some(0), None), + ScalarValue::TimestampMicrosecond(Some(1_000_000), None), + 1_000_000, + ), + ( + ScalarValue::TimestampNanosecond(Some(1_000_000_000), None), + ScalarValue::TimestampNanosecond(Some(2_000_000_000), None), + 1_000_000_000, + ), ]; for (lhs, rhs, expected) in cases.iter() { let distance = lhs.distance(rhs).unwrap(); @@ -7984,8 +9181,6 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), ), - (ScalarValue::Date32(Some(0)), ScalarValue::Date32(Some(1))), - (ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))), ( ScalarValue::Decimal128(Some(123), 5, 5), ScalarValue::Decimal128(Some(120), 5, 3), @@ -8192,7 +9387,7 @@ mod tests { )))), ]; - let check_array = |array| { + let check_array = |array: Arc| { let is_null = is_null(&array).unwrap(); assert_eq!(is_null, BooleanArray::from(vec![true, false, false])); @@ -8257,6 +9452,21 @@ mod tests { "); } + #[test] + fn test_list_view_display() { + let s = ScalarValue::ListView( + ListViewArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]) + .into(), + ); + + assert_eq!(s.to_string(), "[1, , 3]"); + assert_eq!(format!("{s:?}"), "ListView([1, , 3])"); + } + #[test] fn test_null_bug() { let field_a = Field::new("a", DataType::Int32, true); @@ -8621,6 +9831,19 @@ 
mod tests { assert!(dense_scalar.is_null()); } + #[test] + fn cast_date_to_timestamp_overflow_returns_error() { + let scalar = ScalarValue::Date32(Some(i32::MAX)); + let err = scalar + .cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None)) + .expect_err("expected cast to fail"); + assert!( + err.to_string() + .contains("converted value exceeds the representable i64 range"), + "unexpected error: {err}" + ); + } + #[test] fn null_dictionary_scalar_produces_null_dictionary_array() { let dictionary_scalar = ScalarValue::Dictionary( @@ -8702,13 +9925,17 @@ mod tests { 42, )) .unwrap(), + ScalarValue::try_new_null(&DataType::ListView(Arc::clone(&field_ref))) + .unwrap(), + ScalarValue::try_new_null(&DataType::LargeListView(Arc::clone(&field_ref))) + .unwrap(), ScalarValue::try_new_null(&DataType::Struct( vec![Arc::clone(&field_ref)].into(), )) .unwrap(), ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(), ScalarValue::try_new_null(&DataType::Union( - UnionFields::new(vec![42], vec![field_ref]), + UnionFields::try_new(vec![42], vec![field_ref]).unwrap(), UnionMode::Dense, )) .unwrap(), @@ -8794,6 +10021,41 @@ mod tests { _ => panic!("Expected List"), } + let list_field = Field::new_list_field(DataType::Int32, true); + let list_result = + ScalarValue::new_default(&DataType::LargeList(Arc::new(list_field.clone()))) + .unwrap(); + match list_result { + ScalarValue::LargeList(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.value_length(0), 0); // empty list + } + _ => panic!("Expected LargeList"), + } + + let list_result = + ScalarValue::new_default(&DataType::ListView(Arc::new(list_field.clone()))) + .unwrap(); + match list_result { + ScalarValue::ListView(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.value_size(0), 0); // empty list + } + _ => panic!("Expected ListView"), + } + + let list_result = ScalarValue::new_default(&DataType::LargeListView(Arc::new( + list_field.clone(), + ))) + .unwrap(); + match list_result { + ScalarValue::LargeListView(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.value_size(0), 0); // empty list + } + _ => panic!("Expected LargeListView"), + } + // Test struct type let struct_fields = Fields::from(vec![ Field::new("a", DataType::Int32, false), @@ -8811,13 +10073,14 @@ mod tests { } // Test union type - let union_fields = UnionFields::new( + let union_fields = UnionFields::try_new( vec![0, 1], vec![ Field::new("i32", DataType::Int32, false), Field::new("f64", DataType::Float64, false), ], - ); + ) + .unwrap(); let union_result = ScalarValue::new_default(&DataType::Union( union_fields.clone(), UnionMode::Sparse, @@ -8902,6 +10165,30 @@ mod tests { )))), None ); + assert_eq!( + ScalarValue::min(&DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); + assert_eq!( + ScalarValue::min(&DataType::ListView(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); + assert_eq!( + ScalarValue::min(&DataType::LargeListView(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); } #[test] @@ -8978,6 +10265,22 @@ mod tests { )]))), None ); + assert_eq!( + ScalarValue::max(&DataType::ListView(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); + assert_eq!( + ScalarValue::max(&DataType::LargeListView(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); } #[test] @@ -9046,6 +10349,196 @@ mod tests { } } + #[test] + fn test_views_minimize_memory() { + let value = "this string is longer than 12 
bytes".to_string(); + + let scalar = ScalarValue::Utf8View(Some(value.clone())); + let array = scalar.to_array_of_size(10).unwrap(); + let array = array.as_string_view(); + let buffers = array.data_buffers(); + assert_eq!(1, buffers.len()); + // Ensure we only have a single copy of the value string + assert_eq!(value.len(), buffers[0].len()); + + // Same but for BinaryView + let scalar = ScalarValue::BinaryView(Some(value.bytes().collect())); + let array = scalar.to_array_of_size(10).unwrap(); + let array = array.as_binary_view(); + let buffers = array.data_buffers(); + assert_eq!(1, buffers.len()); + assert_eq!(value.len(), buffers[0].len()); + } + + #[test] + fn test_to_array_of_size_run_end_encoded() { + fn run_test() { + let value = Box::new(ScalarValue::Float32(Some(1.0))); + let size = 5; + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", R::DATA_TYPE, false).into(), + Field::new("values", DataType::Float32, true).into(), + value.clone(), + ); + let array = scalar.to_array_of_size(size).unwrap(); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!(vec![Some(1.0); size], array.into_iter().collect::>()); + assert_eq!(1, array.values().len()); + } + + run_test::(); + run_test::(); + run_test::(); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let err = scalar.to_array_of_size(i16::MAX as usize + 10).unwrap_err(); + assert_eq!( + "Execution error: Cannot construct RunArray of size 32777: Overflows run-ends type Int16", + err.to_string() + ) + } + + #[test] + fn test_eq_array_run_end_encoded() { + let run_ends = Int16Array::from(vec![1, 3]); + let values = Float32Array::from(vec![None, Some(1.0)]); + let run_array = + Arc::new(RunArray::try_new(&run_ends, &values).unwrap()) as ArrayRef; + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + assert!(scalar.eq_array(&run_array, 0).unwrap()); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + assert!(scalar.eq_array(&run_array, 1).unwrap()); + assert!(scalar.eq_array(&run_array, 2).unwrap()); + + // value types must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float64, true).into(), + Box::new(ScalarValue::Float64(Some(1.0))), + ); + let err = scalar.eq_array(&run_array, 1).unwrap_err(); + let expected = "Internal error: could not cast array of type Float32 to arrow_array::array::primitive_array::PrimitiveArray"; + assert!(err.to_string().starts_with(expected)); + + // run ends type must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + let err = scalar.eq_array(&run_array, 0).unwrap_err(); + let expected = "Internal error: could not cast array of type RunEndEncoded(\"run_ends\": non-null Int16, \"values\": Float32) to arrow_array::array::run_array::RunArray"; + assert!(err.to_string().starts_with(expected)); + } + + #[test] + fn test_iter_to_array_run_end_encoded() { + let 
run_ends_field = Arc::new(Field::new("run_ends", DataType::Int16, false)); + let values_field = Arc::new(Field::new("values", DataType::Int64, true)); + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(None)), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ]; + + let run_array = ScalarValue::iter_to_array(scalars).unwrap(); + let expected = RunArray::try_new( + &Int16Array::from(vec![2, 3, 6]), + &Int64Array::from(vec![Some(1), None, Some(2)]), + ) + .unwrap(); + assert_eq!(&expected as &dyn Array, run_array.as_ref()); + + // inconsistent run-ends type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int32 }, Field { name: \"values\", data_type: Int64, nullable: true }, Int64(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent value type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Field::new("values", DataType::Int32, true).into(), + Box::new(ScalarValue::Int32(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int16 }, Field { name: \"values\", data_type: Int32, nullable: true }, Int32(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent scalars type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::Int64(Some(1)), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: Int64(1)"; + assert!(err.to_string().starts_with(expected)); + } + #[test] fn test_convert_array_to_scalar_vec() { // 1: Regular ListArray @@ -9166,5 +10659,52 @@ mod tests { ]), ] ); + + // 6: Regular ListViewArray + let list = ListViewArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = 
ScalarValue::convert_array_to_scalar_vec(&list).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ]), + ] + ); + + // 7: Regular LargeListViewArray + let large_list = + LargeListViewArray::from_iter_primitive::<Int64Type, _, _>(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&large_list).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ]), + ] + ); } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index 56daee904514a..045b5778243df 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -83,6 +83,7 @@ impl ScalarStructBuilder { } /// Add the specified field and `ScalarValue` to the struct. + #[expect(clippy::needless_pass_by_value)] // kept pass-by-value for public API compatibility pub fn with_scalar(self, field: impl IntoFieldRef, value: ScalarValue) -> Self { // valid scalar value should not fail let array = value.to_array().unwrap(); diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index da298c20ebcb4..29b9c36c0a7ea 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -22,17 +22,40 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; use crate::error::_plan_err; +use crate::utils::aggregate::precision_add; use arrow::datatypes::{DataType, Schema}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information about the precision of statistical values. #[derive(Clone, PartialEq, Eq, Default, Copy)] pub enum Precision<T: Debug + Clone + PartialEq + Eq + PartialOrd> { - /// The exact value is known + /// The exact value is known. Used for guaranteeing correctness. + /// + /// Comes from definitive sources such as: + /// - Parquet file metadata (row counts, byte sizes) + /// - In-memory RecordBatch data (actual row counts, byte sizes, null counts) + /// - and more... Exact(T), - /// The value is not known exactly, but is likely close to this value + /// The value is not known exactly, but is likely close to this value. + /// Used for cost-based optimizations. + /// + /// Some operations that would result in `Inexact(T)` would be: + /// - Applying a filter (selectivity is unknown) + /// - Mixing exact and inexact values in arithmetic + /// - and more... Inexact(T), - /// Nothing is known about the value + /// Nothing is known about the value. This is the default state. + /// + /// Acts as an absorbing element in arithmetic -> any operation + /// involving `Absent` yields `Absent`. [`Precision::to_inexact`] + /// on `Absent` returns `Absent`, not `Inexact`; it represents + /// a fundamentally different state.
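+ /// + /// An illustrative sketch of the absorbing behavior (an `ignore`d example, + /// not from this patch; assumes `usize` statistics): + /// + /// ```ignore + /// let absent: Precision<usize> = Precision::Absent; + /// assert_eq!(Precision::Exact(10).add(&absent), Precision::Absent); + /// assert_eq!(absent.to_inexact(), Precision::Absent); // still `Absent` + /// ```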
+ /// + /// Common sources include: + /// - Data sources without statistics + /// - Parquet columns missing from file metadata + /// - Statistics that cannot be derived for an operation (e.g., + /// `distinct_count` after a union, `total_byte_size` for joins) #[default] Absent, } @@ -180,24 +203,74 @@ impl Precision<ScalarValue> { + fn sum_data_type(data_type: &DataType) -> DataType { + match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 => DataType::Int64, + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => DataType::UInt64, + _ => data_type.clone(), + } + } + + fn cast_scalar_to_sum_type(value: &ScalarValue) -> Result<ScalarValue> { + let source_type = value.data_type(); + let target_type = Self::sum_data_type(&source_type); + if source_type == target_type { + Ok(value.clone()) + } else { + value.cast_to(&target_type) + } + } + /// Calculates the sum of two (possibly inexact) [`ScalarValue`] values, /// conservatively propagating exactness information. If one of the input /// values is [`Precision::Absent`], the result is `Absent` too. + /// + /// Uses [`ScalarValue::add_checked`] so that integer overflow returns + /// an error (mapped to `Absent`) instead of silently wrapping. + /// + /// For performance-sensitive paths, prefer `precision_add`, which + /// avoids the Arrow array round-trip. pub fn add(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => { - a.add(b).map(Precision::Exact).unwrap_or(Precision::Absent) - } + (Precision::Exact(a), Precision::Exact(b)) => a + .add_checked(b) + .map(Precision::Exact) + .unwrap_or(Precision::Absent), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) | (Precision::Inexact(a), Precision::Inexact(b)) => a - .add(b) + .add_checked(b) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), (_, _) => Precision::Absent, } } + + /// Casts integer values to the wider SQL `SUM` return type. + /// + /// This reduces overflow risk when `sum_value` statistics are merged: + /// `Int8/Int16/Int32 -> Int64` and `UInt8/UInt16/UInt32 -> UInt64`. + pub fn cast_to_sum_type(&self) -> Precision<ScalarValue> { + match (self.is_exact(), self.get_value()) { + (Some(true), Some(value)) => Self::cast_scalar_to_sum_type(value) + .map(Precision::Exact) + .unwrap_or(Precision::Absent), + (Some(false), Some(value)) => Self::cast_scalar_to_sum_type(value) .map(Precision::Inexact) .unwrap_or(Precision::Absent), (_, _) => Precision::Absent, } } + /// SUM-style addition with integer widening to match SQL `SUM` return + /// types for smaller integral inputs. + pub fn add_for_sum(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> { + let mut lhs = self.cast_to_sum_type(); + let rhs = other.cast_to_sum_type(); + precision_add(&mut lhs, &rhs); + lhs + } + /// Calculates the difference of two (possibly inexact) [`ScalarValue`] values, /// conservatively propagating exactness information. If one of the input /// values is [`Precision::Absent`], the result is `Absent` too. @@ -283,9 +356,14 @@ impl From<Option<T>> for Precision<T> { /// and the transformations output are not always predictable. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Statistics { - /// The number of table rows. + /// The number of rows estimated to be scanned. pub num_rows: Precision<usize>, - /// Total bytes of the table rows. + /// The total bytes of the output data. + /// + /// Note that this is not the same as the total bytes that may be scanned, + /// processed, etc. + /// E.g.
we may read 1GB of data from a Parquet file but the Arrow data + /// the node produces may be 2GB; it's this 2GB that is tracked here. pub total_byte_size: Precision, /// Statistics on a column level. /// @@ -294,6 +372,27 @@ pub struct Statistics { pub column_statistics: Vec, } +/// Fallback to use when NDV overlap can not be estimated from column bounds. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum NdvFallback { + /// Use the larger input NDV. This is the conservative default for + /// related fragments such as files from the same table. + #[default] + Max, + /// Sum the input NDVs. This is a conservative upper bound for + /// independent inputs such as `UNION ALL`. + Sum, +} + +impl NdvFallback { + fn merge(self, left: usize, right: usize) -> usize { + match self { + Self::Max => usize::max(left, right), + Self::Sum => left.saturating_add(right), + } + } +} + impl Default for Statistics { /// Returns a new [`Statistics`] instance with all fields set to unknown /// and no columns. @@ -317,6 +416,31 @@ impl Statistics { } } + /// Calculates `total_byte_size` based on the schema and `num_rows`. + /// If any of the columns has non-primitive width, `total_byte_size` is set to inexact. + pub fn calculate_total_byte_size(&mut self, schema: &Schema) { + let mut row_size = Some(0); + for field in schema.fields() { + match field.data_type().primitive_width() { + Some(width) => { + row_size = row_size.map(|s| s + width); + } + None => { + row_size = None; + break; + } + } + } + match row_size { + None => { + self.total_byte_size = self.total_byte_size.to_inexact(); + } + Some(size) => { + self.total_byte_size = self.num_rows.multiply(&Precision::Exact(size)); + } + } + } + /// Returns an unbounded `ColumnStatistics` for each field in the schema. pub fn unknown_column(schema: &Schema) -> Vec { schema @@ -362,12 +486,17 @@ impl Statistics { /// For example, if we had statistics for columns `{"a", "b", "c"}`, /// projecting to `vec![2, 1]` would return statistics for columns `{"c", /// "b"}`. 
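+ /// + /// A sketch of that call (illustrative only, not from this patch; assumes + /// `stats` covers columns `a`, `b`, `c`): + /// + /// ```ignore + /// let projected = stats.project(Some(&vec![2, 1])); + /// // projected.column_statistics is now [stats for c, stats for b] + /// ```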
- pub fn project(mut self, projection: Option<&Vec>) -> Self { - let Some(projection) = projection else { + pub fn project(self, projection: Option<&impl AsRef<[usize]>>) -> Self { + let projection = projection.map(AsRef::as_ref); + self.project_impl(projection) + } + + fn project_impl(mut self, projection: Option<&[usize]>) -> Self { + let Some(projection) = projection.map(AsRef::as_ref) else { return self; }; - #[allow(clippy::large_enum_variant)] + #[expect(clippy::large_enum_variant)] enum Slot { /// The column is taken and put into the specified statistics location Taken(usize), @@ -381,7 +510,7 @@ impl Statistics { .map(Slot::Present) .collect(); - for idx in projection { + for idx in projection.iter() { let next_idx = self.column_statistics.len(); let slot = std::mem::replace( columns.get_mut(*idx).expect("projection out of bounds"), @@ -477,15 +606,42 @@ impl Statistics { self.column_statistics = self .column_statistics .into_iter() - .map(ColumnStatistics::to_inexact) + .map(|cs| { + let mut cs = cs.to_inexact(); + // Scale byte_size by the row ratio + cs.byte_size = match cs.byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + Precision::Inexact((n as f64 * ratio) as usize) + } + Precision::Absent => Precision::Absent, + }; + // NDV can never exceed the number of rows + if let Some(&rows) = self.num_rows.get_value() { + cs.distinct_count = cs.distinct_count.min(&Precision::Inexact(rows)); + } + cs + }) .collect(); - // Adjust the total_byte_size for the ratio of rows before and after, also marking it as inexact - self.total_byte_size = match &self.total_byte_size { - Precision::Exact(n) | Precision::Inexact(n) => { - let adjusted = (*n as f64 * ratio) as usize; - Precision::Inexact(adjusted) + + // Compute total_byte_size as sum of column byte_size values if all are present, + // otherwise fall back to scaling the original total_byte_size + let sum_scan_bytes: Option = self + .column_statistics + .iter() + .map(|cs| cs.byte_size.get_value().copied()) + .try_fold(0usize, |acc, val| val.map(|v| acc + v)); + + self.total_byte_size = match sum_scan_bytes { + Some(sum) => Precision::Inexact(sum), + None => { + // Fall back to scaling original total_byte_size if not all columns have byte_size + match &self.total_byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + Precision::Inexact((*n as f64 * ratio) as usize) + } + Precision::Absent => Precision::Absent, + } } - Precision::Absent => Precision::Absent, }; Ok(self) } @@ -495,24 +651,10 @@ impl Statistics { /// The method assumes that all statistics are for the same schema. /// If not, maybe you can call `SchemaMapper::map_column_statistics` to make them consistent. /// - /// Returns an error if the statistics do not match the specified schemas. - pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result - where - I: IntoIterator, - { - let mut items = items.into_iter(); - - let Some(init) = items.next() else { - return Ok(Statistics::new_unknown(schema)); - }; - items.try_fold(init.clone(), |acc: Statistics, item_stats: &Statistics| { - acc.try_merge(item_stats) - }) - } - - /// Merge this Statistics value with another Statistics value. + /// This method uses [`NdvFallback::Max`] when `distinct_count` overlap + /// can not be estimated from column bounds. /// - /// Returns an error if the statistics do not match (different schemas). + /// Returns an error if the statistics do not match the specified schemas. 
/// /// # Example /// ``` @@ -520,67 +662,110 @@ impl Statistics { /// # use arrow::datatypes::{Field, Schema, DataType}; /// # use datafusion_common::stats::Precision; /// let stats1 = Statistics::default() - /// .with_num_rows(Precision::Exact(1)) - /// .with_total_byte_size(Precision::Exact(2)) + /// .with_num_rows(Precision::Exact(10)) /// .add_column_statistics( /// ColumnStatistics::new_unknown() - /// .with_null_count(Precision::Exact(3)) - /// .with_min_value(Precision::Exact(ScalarValue::from(4))) - /// .with_max_value(Precision::Exact(ScalarValue::from(5))), + /// .with_min_value(Precision::Exact(ScalarValue::from(1))) + /// .with_max_value(Precision::Exact(ScalarValue::from(100))) + /// .with_sum_value(Precision::Exact(ScalarValue::from(500))), /// ); /// /// let stats2 = Statistics::default() - /// .with_num_rows(Precision::Exact(10)) - /// .with_total_byte_size(Precision::Inexact(20)) + /// .with_num_rows(Precision::Exact(20)) /// .add_column_statistics( /// ColumnStatistics::new_unknown() - /// // absent null count - /// .with_min_value(Precision::Exact(ScalarValue::from(40))) - /// .with_max_value(Precision::Exact(ScalarValue::from(50))), + /// .with_min_value(Precision::Exact(ScalarValue::from(5))) + /// .with_max_value(Precision::Exact(ScalarValue::from(200))) + /// .with_sum_value(Precision::Exact(ScalarValue::from(1000))), /// ); /// - /// let merged_stats = stats1.try_merge(&stats2).unwrap(); - /// let expected_stats = Statistics::default() - /// .with_num_rows(Precision::Exact(11)) - /// .with_total_byte_size(Precision::Inexact(22)) // inexact in stats2 --> inexact - /// .add_column_statistics( - /// ColumnStatistics::new_unknown() - /// .with_null_count(Precision::Absent) // missing from stats2 --> absent - /// .with_min_value(Precision::Exact(ScalarValue::from(4))) - /// .with_max_value(Precision::Exact(ScalarValue::from(50))), - /// ); + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let merged = Statistics::try_merge_iter( + /// &[stats1, stats2], + /// &schema, + /// ).unwrap(); /// - /// assert_eq!(merged_stats, expected_stats) + /// assert_eq!(merged.num_rows, Precision::Exact(30)); + /// assert_eq!(merged.column_statistics[0].min_value, + /// Precision::Exact(ScalarValue::from(1))); + /// assert_eq!(merged.column_statistics[0].max_value, + /// Precision::Exact(ScalarValue::from(200))); + /// assert_eq!(merged.column_statistics[0].sum_value, + /// Precision::Exact(ScalarValue::Int64(Some(1500)))); /// ``` - pub fn try_merge(self, other: &Statistics) -> Result { - let Self { - mut num_rows, - mut total_byte_size, - mut column_statistics, - } = self; - - // Accumulate statistics for subsequent items - num_rows = num_rows.add(&other.num_rows); - total_byte_size = total_byte_size.add(&other.total_byte_size); - - if column_statistics.len() != other.column_statistics.len() { - return _plan_err!( - "Cannot merge statistics with different number of columns: {} vs {}", - column_statistics.len(), - other.column_statistics.len() - ); + pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result + where + I: IntoIterator, + { + Self::try_merge_iter_with_ndv_fallback(items, schema, NdvFallback::Max) + } + + /// Same as [`Statistics::try_merge_iter`], but lets callers choose the + /// fallback used when `distinct_count` overlap can not be estimated. 
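+ /// + /// For example (a sketch, not from this patch; `left` and `right` stand for + /// statistics of independent `UNION ALL` inputs): + /// + /// ```ignore + /// let merged = Statistics::try_merge_iter_with_ndv_fallback( + /// [&left, &right], + /// &schema, + /// NdvFallback::Sum, // independent inputs: NDVs are additive + /// )?; + /// ```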
+ pub fn try_merge_iter_with_ndv_fallback<'a, I>( + items: I, + schema: &Schema, + ndv_fallback: NdvFallback, + ) -> Result + where + I: IntoIterator, + { + let mut items = items.into_iter(); + let Some(first) = items.next() else { + return Ok(Statistics::new_unknown(schema)); + }; + let Some(second) = items.next() else { + return Ok(first.clone()); + }; + + let num_cols = first.column_statistics.len(); + let mut num_rows = first.num_rows; + let mut total_byte_size = first.total_byte_size; + let mut column_statistics = first.column_statistics.clone(); + for col_stats in &mut column_statistics { + cast_sum_value_to_sum_type_in_place(&mut col_stats.sum_value); } - for (item_col_stats, col_stats) in other - .column_statistics - .iter() - .zip(column_statistics.iter_mut()) - { - col_stats.null_count = col_stats.null_count.add(&item_col_stats.null_count); - col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value); - col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); - col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); - col_stats.distinct_count = Precision::Absent; + // Merge the remaining items in a single pass. + for (i, stat) in std::iter::once(second).chain(items).enumerate() { + if stat.column_statistics.len() != num_cols { + return _plan_err!( + "Cannot merge statistics with different number of columns: {} vs {} (item {})", + num_cols, + stat.column_statistics.len(), + i + 1 + ); + } + num_rows = num_rows.add(&stat.num_rows); + total_byte_size = total_byte_size.add(&stat.total_byte_size); + + // Uses precision_add for sum (reuses the lhs accumulator for + // direct numeric addition), while preserving the NDV update + // ordering required by estimate_ndv_with_overlap. + for (col_stats, item_cs) in + column_statistics.iter_mut().zip(&stat.column_statistics) + { + col_stats.null_count = col_stats.null_count.add(&item_cs.null_count); + + // NDV must be computed before min/max update (needs pre-merge ranges) + col_stats.distinct_count = match ( + col_stats.distinct_count.get_value(), + item_cs.distinct_count.get_value(), + ) { + (Some(&l), Some(&r)) => Precision::Inexact( + estimate_ndv_with_overlap(col_stats, item_cs, l, r) + .unwrap_or_else(|| ndv_fallback.merge(l, r)), + ), + _ => Precision::Absent, + }; + precision_min(&mut col_stats.min_value, &item_cs.min_value); + precision_max(&mut col_stats.max_value, &item_cs.max_value); + precision_add_for_sum_in_place( + &mut col_stats.sum_value, + &item_cs.sum_value, + ); + col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size); + } } Ok(Statistics { @@ -591,6 +776,205 @@ impl Statistics { } } +/// Estimates the combined number of distinct values (NDV) when merging two +/// column statistics, using range overlap to avoid double-counting shared values. +/// +/// Assumes values are distributed uniformly within each input's +/// `[min, max]` range (the standard assumption when only summary +/// statistics are available). Under uniformity the fraction of an input's +/// distinct values that land in a sub-range equals the fraction of +/// the range that sub-range covers. +/// +/// The combined value space is split into three disjoint regions: +/// +/// ```text +/// |-- only A --|-- overlap --|-- only B --| +/// ``` +/// +/// * **Only in A/B** - values outside the other input's range +/// contribute `(1 - overlap_a) * NDV_a` and `(1 - overlap_b) * NDV_b`. +/// * **Overlap** - both inputs may produce values here. 
We take +/// `max(overlap_a * NDV_a, overlap_b * NDV_b)` rather than the +/// sum because values in the same sub-range are likely shared +/// (the smaller set is assumed to be a subset of the larger). +/// +/// The formula ranges between `[max(NDV_a, NDV_b), NDV_a + NDV_b]`, +/// from full overlap to no overlap. +/// +/// ```text +/// NDV = max(overlap_a * NDV_a, overlap_b * NDV_b) [intersection] +/// + (1 - overlap_a) * NDV_a [only in A] +/// + (1 - overlap_b) * NDV_b [only in B] +/// ``` +/// +/// Returns `None` when min/max are absent or distance is unsupported +/// (e.g. strings), in which case the caller should fall back to a simpler +/// estimate. +pub fn estimate_ndv_with_overlap( + left: &ColumnStatistics, + right: &ColumnStatistics, + ndv_left: usize, + ndv_right: usize, +) -> Option<usize> { + let left_min = left.min_value.get_value()?; + let left_max = left.max_value.get_value()?; + let right_min = right.min_value.get_value()?; + let right_max = right.max_value.get_value()?; + + let range_left = left_max.distance(left_min)?; + let range_right = right_max.distance(right_min)?; + + // Constant columns (range == 0) can't use the proportional overlap + // formula below, so check interval overlap directly instead. + if range_left == 0 || range_right == 0 { + let overlaps = left_min <= right_max && right_min <= left_max; + return Some(if overlaps { + usize::max(ndv_left, ndv_right) + } else { + ndv_left + ndv_right + }); + } + + let overlap_min = if left_min >= right_min { + left_min + } else { + right_min + }; + let overlap_max = if left_max <= right_max { + left_max + } else { + right_max + }; + + // Disjoint ranges: no overlap, NDVs are additive + if overlap_min > overlap_max { + return Some(ndv_left + ndv_right); + } + + let overlap_range = overlap_max.distance(overlap_min)? as f64; + + let overlap_left = overlap_range / range_left as f64; + let overlap_right = overlap_range / range_right as f64; + + let intersection = f64::max( + overlap_left * ndv_left as f64, + overlap_right * ndv_right as f64, + ); + let only_left = (1.0 - overlap_left) * ndv_left as f64; + let only_right = (1.0 - overlap_right) * ndv_right as f64; + + Some((intersection + only_left + only_right).round() as usize) +} + +/// Returns the minimum precision while not allocating a new value; +/// mirrors the semantics of `PartialOrd`. +#[inline] +fn precision_min<T>(lhs: &mut Precision<T>, rhs: &Precision<T>) +where + T: Debug + Clone + PartialEq + Eq + PartialOrd, +{ + *lhs = match (std::mem::take(lhs), rhs) { + (Precision::Exact(left), Precision::Exact(right)) => { + if left <= *right { + Precision::Exact(left) + } else { + Precision::Exact(right.clone()) + } + } + (Precision::Exact(left), Precision::Inexact(right)) + | (Precision::Inexact(left), Precision::Exact(right)) + | (Precision::Inexact(left), Precision::Inexact(right)) => { + if left <= *right { + Precision::Inexact(left) + } else { + Precision::Inexact(right.clone()) + } + } + (_, _) => Precision::Absent, + }; +} + +/// Returns the maximum precision while not allocating a new value; +/// mirrors the semantics of `PartialOrd`.
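+/// +/// An illustrative sketch (not a doctest): +/// +/// ```ignore +/// let mut lhs = Precision::Exact(3); +/// precision_max(&mut lhs, &Precision::Inexact(5)); +/// assert_eq!(lhs, Precision::Inexact(5)); // exactness is demoted +/// ```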
+#[inline] +fn precision_max<T>(lhs: &mut Precision<T>, rhs: &Precision<T>) +where + T: Debug + Clone + PartialEq + Eq + PartialOrd, +{ + *lhs = match (std::mem::take(lhs), rhs) { + (Precision::Exact(left), Precision::Exact(right)) => { + if left >= *right { + Precision::Exact(left) + } else { + Precision::Exact(right.clone()) + } + } + (Precision::Exact(left), Precision::Inexact(right)) + | (Precision::Inexact(left), Precision::Exact(right)) + | (Precision::Inexact(left), Precision::Inexact(right)) => { + if left >= *right { + Precision::Inexact(left) + } else { + Precision::Inexact(right.clone()) + } + } + (_, _) => Precision::Absent, + }; +} + +#[inline] +fn cast_sum_value_to_sum_type_in_place(value: &mut Precision<ScalarValue>) { + let (is_exact, inner) = match std::mem::take(value) { + Precision::Exact(v) => (true, v), + Precision::Inexact(v) => (false, v), + Precision::Absent => return, + }; + let source_type = inner.data_type(); + let target_type = Precision::<ScalarValue>::sum_data_type(&source_type); + + let wrap_precision_fn: fn(ScalarValue) -> Precision<ScalarValue> = if is_exact { + Precision::Exact + } else { + Precision::Inexact + }; + + *value = if source_type == target_type { + wrap_precision_fn(inner) + } else { + inner + .cast_to(&target_type) + .map(wrap_precision_fn) + .unwrap_or(Precision::Absent) + }; +} + +#[inline] +fn precision_add_for_sum_in_place( + lhs: &mut Precision<ScalarValue>, + rhs: &Precision<ScalarValue>, +) { + let (value, wrap_fn): (&ScalarValue, fn(ScalarValue) -> Precision<ScalarValue>) = + match rhs { + Precision::Exact(v) => (v, Precision::Exact), + Precision::Inexact(v) => (v, Precision::Inexact), + Precision::Absent => { + *lhs = Precision::Absent; + return; + } + }; + let source_type = value.data_type(); + let target_type = Precision::<ScalarValue>::sum_data_type(&source_type); + if source_type == target_type { + precision_add(lhs, rhs); + } else { + let rhs = value + .cast_to(&target_type) + .map(wrap_fn) + .unwrap_or(Precision::Absent); + precision_add(lhs, &rhs); + } +} + /// Creates an estimate of the number of rows in the output using the given /// optional value and exactness flag. fn check_num_rows(value: Option<usize>, is_exact: bool) -> Precision<usize> { @@ -642,6 +1026,11 @@ impl Display for Statistics { } else { s }; + let s = if cs.byte_size != Precision::Absent { + format!("{} ScanBytes={}", s, cs.byte_size) + } else { + s + }; s + ")" }) @@ -667,10 +1056,33 @@ pub struct ColumnStatistics { pub max_value: Precision<ScalarValue>, /// Minimum value of column pub min_value: Precision<ScalarValue>, - /// Sum value of a column + /// Sum value of a column. + /// + /// For integral columns, values should be kept in SUM-compatible widened + /// types (`Int8/Int16/Int32 -> Int64`, `UInt8/UInt16/UInt32 -> UInt64`) to + /// reduce overflow risk during statistics propagation. + /// + /// Callers should prefer [`ColumnStatistics::with_sum_value`] for setting + /// this field and [`Precision::add_for_sum`] / + /// [`Precision::cast_to_sum_type`] for sum arithmetic. pub sum_value: Precision<ScalarValue>, /// Number of distinct values pub distinct_count: Precision<usize>, + /// Estimated size of this column's data in bytes for the output. + /// + /// Note that this is not the same as the total bytes that may be scanned, + /// processed, etc. + /// + /// E.g. we may read 1GB of data from a Parquet file but the Arrow data + /// the node produces may be 2GB; it's this 2GB that is tracked here. + /// + /// Currently this is accurately calculated for primitive types only. + /// For complex types (like Utf8, List, Struct, etc), this value may be + /// absent or inexact (e.g.
estimated from the size of the data in the source Parquet files). + /// + /// This value is automatically scaled when operations like limits or + /// filters reduce the number of rows (see [`Statistics::with_fetch`]). pub byte_size: Precision<usize>, } impl ColumnStatistics { @@ -693,6 +1105,7 @@ impl ColumnStatistics { min_value: Precision::Absent, sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, } } @@ -716,7 +1129,19 @@ impl ColumnStatistics { /// Set the sum value pub fn with_sum_value(mut self, sum_value: Precision<ScalarValue>) -> Self { - self.sum_value = sum_value; + self.sum_value = match sum_value { + Precision::Exact(value) => { + Precision::<ScalarValue>::cast_scalar_to_sum_type(&value) + .map(Precision::Exact) + .unwrap_or(Precision::Absent) + } + Precision::Inexact(value) => { + Precision::<ScalarValue>::cast_scalar_to_sum_type(&value) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent) + } + Precision::Absent => Precision::Absent, + }; self } @@ -726,6 +1151,13 @@ impl ColumnStatistics { self } + /// Set the scan byte size. + /// This should initially be set to the total size of the column. + pub fn with_byte_size(mut self, byte_size: Precision<usize>) -> Self { + self.byte_size = byte_size; + self + } + /// If the exactness of a [`ColumnStatistics`] instance is lost, this /// function relaxes the exactness of all information by converting them /// to [`Precision::Inexact`]. @@ -735,6 +1167,7 @@ impl ColumnStatistics { self.min_value = self.min_value.to_inexact(); self.sum_value = self.sum_value.to_inexact(); self.distinct_count = self.distinct_count.to_inexact(); + self.byte_size = self.byte_size.to_inexact(); self } } @@ -861,6 +1294,45 @@ mod tests { assert_eq!(precision.add(&Precision::Absent), Precision::Absent); } + #[test] + fn test_add_for_sum_scalar_integer_widening() { + let precision = Precision::Exact(ScalarValue::Int32(Some(42))); + + assert_eq!( + precision.add_for_sum(&Precision::Exact(ScalarValue::Int32(Some(23)))), + Precision::Exact(ScalarValue::Int64(Some(65))), + ); + assert_eq!( + precision.add_for_sum(&Precision::Inexact(ScalarValue::Int32(Some(23)))), + Precision::Inexact(ScalarValue::Int64(Some(65))), + ); + } + + #[test] + fn test_add_for_sum_prevents_int32_overflow() { + let lhs = Precision::Exact(ScalarValue::Int32(Some(i32::MAX))); + let rhs = Precision::Exact(ScalarValue::Int32(Some(1))); + + assert_eq!( + lhs.add_for_sum(&rhs), + Precision::Exact(ScalarValue::Int64(Some(i64::from(i32::MAX) + 1))), + ); + } + + #[test] + fn test_add_for_sum_scalar_unsigned_integer_widening() { + let precision = Precision::Exact(ScalarValue::UInt32(Some(42))); + + assert_eq!( + precision.add_for_sum(&Precision::Exact(ScalarValue::UInt32(Some(23)))), + Precision::Exact(ScalarValue::UInt64(Some(65))), + ); + assert_eq!( + precision.add_for_sum(&Precision::Inexact(ScalarValue::UInt32(Some(23)))), + Precision::Inexact(ScalarValue::UInt64(Some(65))), + ); + } + #[test] fn test_sub() { let precision1 = Precision::Exact(42); @@ -961,9 +1433,11 @@ mod tests { Precision::Exact(ScalarValue::Int64(None)), ); // Overflow returns error - assert!(Precision::Exact(ScalarValue::Int32(Some(256))) - .cast_to(&DataType::Int8) - .is_err()); + assert!( + Precision::Exact(ScalarValue::Int32(Some(256))) + .cast_to(&DataType::Int8) + .is_err() + ); } #[test] @@ -976,15 +1450,13 @@ mod tests { // Precision is not copy (requires .clone()) let precision: Precision<ScalarValue> = Precision::Exact(ScalarValue::Int64(Some(42))); - // Clippy would complain about this if it were Copy -
#[allow(clippy::redundant_clone)] let p2 = precision.clone(); assert_eq!(precision, p2); } #[test] fn test_project_none() { - let projection = None; + let projection: Option> = None; let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); assert_eq!(stats, make_stats(vec![10, 20, 30])); } @@ -1026,11 +1498,50 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(64))), sum_value: Precision::Exact(ScalarValue::Int64(Some(4600))), distinct_count: Precision::Exact(100), + byte_size: Precision::Exact(800), } } + fn make_single_i64_ndv_stats( + distinct_count: Precision, + min_value: Option, + max_value: Option, + ) -> Statistics { + let to_precision = |value| Precision::Exact(ScalarValue::Int64(Some(value))); + + Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(distinct_count) + .with_min_value( + min_value.map(to_precision).unwrap_or(Precision::Absent), + ) + .with_max_value( + max_value.map(to_precision).unwrap_or(Precision::Absent), + ), + ) + } + + fn merge_single_i64_ndv_distinct_count( + left: Statistics, + right: Statistics, + ndv_fallback: NdvFallback, + ) -> Precision { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + + Statistics::try_merge_iter_with_ndv_fallback( + [&left, &right], + &schema, + ndv_fallback, + ) + .unwrap() + .column_statistics[0] + .distinct_count + } + #[test] - fn test_try_merge_basic() { + fn test_try_merge() { // Create a schema with two columns let schema = Arc::new(Schema::new(vec![ Field::new("col1", DataType::Int32, false), @@ -1048,6 +1559,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(1))), sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }, ColumnStatistics { null_count: Precision::Exact(2), @@ -1055,6 +1567,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(10))), sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }, ], }; @@ -1069,6 +1582,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), }, ColumnStatistics { null_count: Precision::Exact(3), @@ -1076,6 +1590,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(5))), sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), }, ], }; @@ -1101,7 +1616,7 @@ mod tests { ); assert_eq!( col1_stats.sum_value, - Precision::Exact(ScalarValue::Int32(Some(1100))) + Precision::Exact(ScalarValue::Int64(Some(1100))) ); // 500 + 600 let col2_stats = &summary_stats.column_statistics[1]; @@ -1116,7 +1631,7 @@ mod tests { ); assert_eq!( col2_stats.sum_value, - Precision::Exact(ScalarValue::Int32(Some(2200))) + Precision::Exact(ScalarValue::Int64(Some(2200))) ); // 1000 + 1200 } @@ -1139,6 +1654,7 @@ mod tests { min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }], }; @@ -1151,6 +1667,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), }], }; @@ -1171,7 +1688,7 @@ mod tests 
{ col_stats.min_value, Precision::Inexact(ScalarValue::Int32(Some(-10))) ); - assert!(matches!(col_stats.sum_value, Precision::Absent)); + assert_eq!(col_stats.sum_value, Precision::Absent); } #[test] @@ -1215,7 +1732,10 @@ mod tests { let items = vec![stats1, stats2]; let e = Statistics::try_merge_iter(&items, &schema).unwrap_err(); - assert_contains!(e.to_string(), "Error during planning: Cannot merge statistics with different number of columns: 0 vs 1"); + assert_contains!( + e.to_string(), + "Error during planning: Cannot merge statistics with different number of columns: 0 vs 1" + ); } #[test] @@ -1244,7 +1764,9 @@ mod tests { ); // Merge statistics - let merged_stats = stats1.try_merge(&stats2).unwrap(); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged_stats = + Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); // Verify the results assert_eq!(merged_stats.num_rows, Precision::Exact(25)); @@ -1260,66 +1782,486 @@ mod tests { col_stats.max_value, Precision::Exact(ScalarValue::Int32(Some(20))) ); - // Distinct count should be Absent after merge - assert_eq!(col_stats.distinct_count, Precision::Absent); + // Overlap-based NDV: ranges [1,10] and [5,20], overlap [5,10] + // range_left=9, range_right=15, overlap=5 + // overlap_left=5*(5/9)=2.78, overlap_right=7*(5/15)=2.33 + // result = max(2.78, 2.33) + (5-2.78) + (7-2.33) = 9.67 -> 10 + assert_eq!(col_stats.distinct_count, Precision::Inexact(10)); } #[test] - fn test_with_fetch_basic_preservation() { - // Test that column statistics and byte size are preserved (as inexact) when applying fetch - let original_stats = Statistics { - num_rows: Precision::Exact(1000), - total_byte_size: Precision::Exact(8000), - column_statistics: vec![ - ColumnStatistics { - null_count: Precision::Exact(10), - max_value: Precision::Exact(ScalarValue::Int32(Some(100))), - min_value: Precision::Exact(ScalarValue::Int32(Some(0))), - sum_value: Precision::Exact(ScalarValue::Int32(Some(5050))), - distinct_count: Precision::Exact(50), - }, - ColumnStatistics { - null_count: Precision::Exact(20), - max_value: Precision::Exact(ScalarValue::Int64(Some(200))), - min_value: Precision::Exact(ScalarValue::Int64(Some(10))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(10100))), - distinct_count: Precision::Exact(75), - }, - ], - }; - - // Apply fetch of 100 rows (10% of original) - let result = original_stats.clone().with_fetch(Some(100), 0, 1).unwrap(); - - // Check num_rows - assert_eq!(result.num_rows, Precision::Exact(100)); - - // Check total_byte_size is scaled proportionally and marked as inexact - // 100/1000 = 0.1, so 8000 * 0.1 = 800 - assert_eq!(result.total_byte_size, Precision::Inexact(800)); - - // Check column statistics are preserved but marked as inexact - assert_eq!(result.column_statistics.len(), 2); + fn test_try_merge_ndv_disjoint_ranges() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(20)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(30)))) + .with_distinct_count(Precision::Exact(8)), + ); - // First column - 
assert_eq!( - result.column_statistics[0].null_count, - Precision::Inexact(10) - ); - assert_eq!( - result.column_statistics[0].max_value, - Precision::Inexact(ScalarValue::Int32(Some(100))) - ); - assert_eq!( - result.column_statistics[0].min_value, - Precision::Inexact(ScalarValue::Int32(Some(0))) - ); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // No overlap -> sum of NDVs assert_eq!( - result.column_statistics[0].sum_value, - Precision::Inexact(ScalarValue::Int32(Some(5050))) + merged.column_statistics[0].distinct_count, + Precision::Inexact(13) ); + } + + #[test] + fn test_try_merge_ndv_identical_ranges() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + .with_distinct_count(Precision::Exact(50)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + .with_distinct_count(Precision::Exact(30)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // Full overlap -> max(50, 30) = 50 assert_eq!( - result.column_statistics[0].distinct_count, + merged.column_statistics[0].distinct_count, + Precision::Inexact(50) + ); + } + + #[test] + fn test_try_merge_ndv_partial_overlap() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + .with_distinct_count(Precision::Exact(80)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(50)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(150)))) + .with_distinct_count(Precision::Exact(60)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // overlap=[50,100], range_left=100, range_right=100, overlap_range=50 + // overlap_left=80*(50/100)=40, overlap_right=60*(50/100)=30 + // result = max(40,30) + (80-40) + (60-30) = 40 + 40 + 30 = 110 + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(110) + ); + } + + #[test] + fn test_try_merge_ndv_missing_min_max() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown().with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown().with_distinct_count(Precision::Exact(8)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // No min/max -> default fallback is max + assert_eq!( + merged.column_statistics[0].distinct_count, + 
Precision::Inexact(8) + ); + } + + #[test] + fn test_try_merge_ndv_non_numeric_types() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Utf8(Some( + "aaa".to_string(), + )))) + .with_max_value(Precision::Exact(ScalarValue::Utf8(Some( + "zzz".to_string(), + )))) + .with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Utf8(Some( + "bbb".to_string(), + )))) + .with_max_value(Precision::Exact(ScalarValue::Utf8(Some( + "yyy".to_string(), + )))) + .with_distinct_count(Precision::Exact(8)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // distance() unsupported for strings -> default fallback is max + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(8) + ); + } + + #[test] + fn test_try_merge_ndv_non_numeric_types_sum_fallback() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Utf8(Some( + "aaa".to_string(), + )))) + .with_max_value(Precision::Exact(ScalarValue::Utf8(Some( + "zzz".to_string(), + )))) + .with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Utf8(Some( + "bbb".to_string(), + )))) + .with_max_value(Precision::Exact(ScalarValue::Utf8(Some( + "yyy".to_string(), + )))) + .with_distinct_count(Precision::Exact(8)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let merged = Statistics::try_merge_iter_with_ndv_fallback( + [&stats1, &stats2], + &schema, + NdvFallback::Sum, + ) + .unwrap(); + + // distance() unsupported for strings -> sum fallback is caller-selected + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(13) + ); + } + + #[test] + fn test_try_merge_ndv_constant_columns() { + // Same constant: [5,5]+[5,5] -> max + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_distinct_count(Precision::Exact(1)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_distinct_count(Precision::Exact(1)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(1) + ); + + // Different constants: [5,5]+[10,10] -> sum + let stats3 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + 
.with_max_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_distinct_count(Precision::Exact(1)), + ); + let stats4 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_distinct_count(Precision::Exact(1)), + ); + + let merged = Statistics::try_merge_iter([&stats3, &stats4], &schema).unwrap(); + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(2) + ); + } + + #[test] + fn test_try_merge_ndv_original_union_edge_cases() { + struct NdvTestCase { + name: &'static str, + left_ndv: Precision, + left_min: Option, + left_max: Option, + right_ndv: Precision, + right_min: Option, + right_max: Option, + expected: Precision, + } + + let cases = vec![ + NdvTestCase { + name: "disjoint ranges", + left_ndv: Precision::Exact(5), + left_min: Some(0), + left_max: Some(10), + right_ndv: Precision::Exact(3), + right_min: Some(20), + right_max: Some(30), + expected: Precision::Inexact(8), + }, + NdvTestCase { + name: "identical ranges", + left_ndv: Precision::Exact(10), + left_min: Some(0), + left_max: Some(100), + right_ndv: Precision::Exact(8), + right_min: Some(0), + right_max: Some(100), + expected: Precision::Inexact(10), + }, + NdvTestCase { + name: "partial overlap", + left_ndv: Precision::Exact(100), + left_min: Some(0), + left_max: Some(100), + right_ndv: Precision::Exact(50), + right_min: Some(50), + right_max: Some(150), + expected: Precision::Inexact(125), + }, + NdvTestCase { + name: "right contained in left", + left_ndv: Precision::Exact(100), + left_min: Some(0), + left_max: Some(100), + right_ndv: Precision::Exact(50), + right_min: Some(25), + right_max: Some(75), + expected: Precision::Inexact(100), + }, + NdvTestCase { + name: "same constant value", + left_ndv: Precision::Exact(1), + left_min: Some(5), + left_max: Some(5), + right_ndv: Precision::Exact(1), + right_min: Some(5), + right_max: Some(5), + expected: Precision::Inexact(1), + }, + NdvTestCase { + name: "different constant values", + left_ndv: Precision::Exact(1), + left_min: Some(5), + left_max: Some(5), + right_ndv: Precision::Exact(1), + right_min: Some(10), + right_max: Some(10), + expected: Precision::Inexact(2), + }, + NdvTestCase { + name: "left constant within right range", + left_ndv: Precision::Exact(1), + left_min: Some(5), + left_max: Some(5), + right_ndv: Precision::Exact(10), + right_min: Some(0), + right_max: Some(10), + expected: Precision::Inexact(10), + }, + NdvTestCase { + name: "left constant outside right range", + left_ndv: Precision::Exact(1), + left_min: Some(20), + left_max: Some(20), + right_ndv: Precision::Exact(10), + right_min: Some(0), + right_max: Some(10), + expected: Precision::Inexact(11), + }, + NdvTestCase { + name: "right constant within left range", + left_ndv: Precision::Exact(10), + left_min: Some(0), + left_max: Some(10), + right_ndv: Precision::Exact(1), + right_min: Some(5), + right_max: Some(5), + expected: Precision::Inexact(10), + }, + NdvTestCase { + name: "right constant outside left range", + left_ndv: Precision::Exact(10), + left_min: Some(0), + left_max: Some(10), + right_ndv: Precision::Exact(1), + right_min: Some(20), + right_max: Some(20), + expected: Precision::Inexact(11), + }, + NdvTestCase { + name: "missing bounds exact plus exact", + left_ndv: Precision::Exact(10), + left_min: None, + left_max: None, + right_ndv: 
Precision::Exact(5), + right_min: None, + right_max: None, + expected: Precision::Inexact(15), + }, + NdvTestCase { + name: "missing bounds exact plus inexact", + left_ndv: Precision::Exact(10), + left_min: None, + left_max: None, + right_ndv: Precision::Inexact(5), + right_min: None, + right_max: None, + expected: Precision::Inexact(15), + }, + NdvTestCase { + name: "missing bounds inexact plus inexact", + left_ndv: Precision::Inexact(7), + left_min: None, + left_max: None, + right_ndv: Precision::Inexact(3), + right_min: None, + right_max: None, + expected: Precision::Inexact(10), + }, + NdvTestCase { + name: "exact plus absent", + left_ndv: Precision::Exact(10), + left_min: None, + left_max: None, + right_ndv: Precision::Absent, + right_min: None, + right_max: None, + expected: Precision::Absent, + }, + NdvTestCase { + name: "inexact plus absent", + left_ndv: Precision::Inexact(4), + left_min: None, + left_max: None, + right_ndv: Precision::Absent, + right_min: None, + right_max: None, + expected: Precision::Absent, + }, + ]; + + for case in cases { + let actual = merge_single_i64_ndv_distinct_count( + make_single_i64_ndv_stats(case.left_ndv, case.left_min, case.left_max), + make_single_i64_ndv_stats(case.right_ndv, case.right_min, case.right_max), + NdvFallback::Sum, + ); + + assert_eq!(actual, case.expected, "case {} failed", case.name); + } + } + + #[test] + fn test_with_fetch_basic_preservation() { + // Test that column statistics and byte size are preserved (as inexact) when applying fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(5050))), + distinct_count: Precision::Exact(50), + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Exact(ScalarValue::Int64(Some(200))), + min_value: Precision::Exact(ScalarValue::Int64(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(10100))), + distinct_count: Precision::Exact(75), + byte_size: Precision::Exact(8000), + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.clone().with_fetch(Some(100), 0, 1).unwrap(); + + // Check num_rows + assert_eq!(result.num_rows, Precision::Exact(100)); + + // Check total_byte_size is computed as sum of scaled column byte_size values + // Column 1: 4000 * 0.1 = 400, Column 2: 8000 * 0.1 = 800, Sum = 1200 + assert_eq!(result.total_byte_size, Precision::Inexact(1200)); + + // Check column statistics are preserved but marked as inexact + assert_eq!(result.column_statistics.len(), 2); + + // First column + assert_eq!( + result.column_statistics[0].null_count, + Precision::Inexact(10) + ); + assert_eq!( + result.column_statistics[0].max_value, + Precision::Inexact(ScalarValue::Int32(Some(100))) + ); + assert_eq!( + result.column_statistics[0].min_value, + Precision::Inexact(ScalarValue::Int32(Some(0))) + ); + assert_eq!( + result.column_statistics[0].sum_value, + Precision::Inexact(ScalarValue::Int32(Some(5050))) + ); + assert_eq!( + result.column_statistics[0].distinct_count, Precision::Inexact(50) ); @@ -1358,6 +2300,7 @@ mod tests { min_value: Precision::Inexact(ScalarValue::Int32(Some(0))), sum_value: Precision::Inexact(ScalarValue::Int32(Some(5050))), 
distinct_count: Precision::Inexact(50), + byte_size: Precision::Inexact(4000), }], }; @@ -1366,9 +2309,9 @@ mod tests { // Check num_rows is inexact assert_eq!(result.num_rows, Precision::Inexact(500)); - // Check total_byte_size is scaled and inexact - // 500/1000 = 0.5, so 8000 * 0.5 = 4000 - assert_eq!(result.total_byte_size, Precision::Inexact(4000)); + // Check total_byte_size is computed as sum of scaled column byte_size values + // Column 1: 4000 * 0.5 = 2000, Sum = 2000 + assert_eq!(result.total_byte_size, Precision::Inexact(2000)); // Column stats remain inexact assert_eq!( @@ -1425,8 +2368,8 @@ mod tests { .unwrap(); assert_eq!(result.num_rows, Precision::Exact(300)); - // 300/1000 = 0.3, so 8000 * 0.3 = 2400 - assert_eq!(result.total_byte_size, Precision::Inexact(2400)); + // Column 1: byte_size 800 * (300/500) = 240, Sum = 240 + assert_eq!(result.total_byte_size, Precision::Inexact(240)); } #[test] @@ -1442,8 +2385,8 @@ mod tests { let result = original_stats.clone().with_fetch(Some(100), 0, 4).unwrap(); assert_eq!(result.num_rows, Precision::Exact(400)); - // 400/1000 = 0.4, so 8000 * 0.4 = 3200 - assert_eq!(result.total_byte_size, Precision::Inexact(3200)); + // Column 1: byte_size 800 * 0.4 = 320, Sum = 320 + assert_eq!(result.total_byte_size, Precision::Inexact(320)); } #[test] @@ -1458,6 +2401,7 @@ mod tests { min_value: Precision::Absent, sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }], }; @@ -1496,6 +2440,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-100))), sum_value: Precision::Exact(ScalarValue::Int32(Some(123456))), distinct_count: Precision::Exact(789), + byte_size: Precision::Exact(4000), }; let original_stats = Statistics { @@ -1522,6 +2467,780 @@ mod tests { result_col_stats.sum_value, Precision::Inexact(ScalarValue::Int32(Some(123456))) ); - assert_eq!(result_col_stats.distinct_count, Precision::Inexact(789)); + // NDV is capped at the new row count (250) since 789 > 250 + assert_eq!(result_col_stats.distinct_count, Precision::Inexact(250)); + } + + #[test] + fn test_byte_size_to_inexact() { + let col_stats = ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(5000), + }; + + let inexact = col_stats.to_inexact(); + assert_eq!(inexact.byte_size, Precision::Inexact(5000)); + } + + #[test] + fn test_with_byte_size_builder() { + let col_stats = + ColumnStatistics::new_unknown().with_byte_size(Precision::Exact(8192)); + assert_eq!(col_stats.byte_size, Precision::Exact(8192)); + } + + #[test] + fn test_with_sum_value_builder_widens_small_integers() { + let col_stats = ColumnStatistics::new_unknown() + .with_sum_value(Precision::Exact(ScalarValue::UInt32(Some(123)))); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::UInt64(Some(123))) + ); + } + + #[test] + fn test_with_fetch_scales_byte_size() { + // Test that byte_size is scaled by the row ratio in with_fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: 
Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8000), + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.with_fetch(Some(100), 0, 1).unwrap(); + + // byte_size should be scaled: 4000 * 0.1 = 400, 8000 * 0.1 = 800 + assert_eq!( + result.column_statistics[0].byte_size, + Precision::Inexact(400) + ); + assert_eq!( + result.column_statistics[1].byte_size, + Precision::Inexact(800) + ); + + // total_byte_size should be computed as sum of byte_size values: 400 + 800 = 1200 + assert_eq!(result.total_byte_size, Precision::Inexact(1200)); + } + + #[test] + fn test_with_fetch_total_byte_size_fallback() { + // Test that total_byte_size falls back to scaling when not all columns have byte_size + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, // One column has no byte_size + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.with_fetch(Some(100), 0, 1).unwrap(); + + // total_byte_size should fall back to scaling: 8000 * 0.1 = 800 + assert_eq!(result.total_byte_size, Precision::Inexact(800)); + } + + #[test] + fn test_with_fetch_caps_ndv_at_row_count() { + // NDV=500 but after LIMIT 10, NDV should be capped at 10 + let stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Inexact(500), + ..Default::default() + }], + }; + + let result = stats.with_fetch(Some(10), 0, 1).unwrap(); + assert_eq!(result.num_rows, Precision::Exact(10)); + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(10) + ); + } + + #[test] + fn test_with_fetch_caps_ndv_with_skip() { + // 1000 rows, NDV=500, OFFSET 5 LIMIT 10 + // with_fetch computes num_rows = min(1000 - 5, 10) = 10 + // NDV should be capped at 10 + let stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Inexact(500), + ..Default::default() + }], + }; + + let result = stats.with_fetch(Some(10), 5, 1).unwrap(); + assert_eq!(result.num_rows, Precision::Exact(10)); + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(10) + ); + } + + #[test] + fn test_with_fetch_caps_ndv_with_large_skip() { + // 1000 rows, NDV=500, OFFSET 995 LIMIT 100 + // with_fetch computes num_rows = min(1000 - 995, 100) = 5 + // NDV should be capped at 5 + let stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Inexact(500), + ..Default::default() + }], + }; + + let result = stats.with_fetch(Some(100), 995, 1).unwrap(); + assert_eq!(result.num_rows, Precision::Exact(5)); + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(5) + ); + } + + 
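+ // Illustrative sketch only (not part of this patch): exercises the new + // `calculate_total_byte_size` helper on an all-primitive schema, assuming + // the behavior documented on the method. + #[test] + fn test_calculate_total_byte_size_sketch() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // 8 bytes per row + Field::new("b", DataType::Float32, false), // 4 bytes per row + ]); + let mut stats = Statistics::default().with_num_rows(Precision::Exact(10)); + stats.calculate_total_byte_size(&schema); + // 10 rows * (8 + 4) bytes per row = 120 bytes, exact since num_rows is exact + assert_eq!(stats.total_byte_size, Precision::Exact(120)); + } +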
#[test] + fn test_with_fetch_ndv_below_row_count_unchanged() { + // NDV=5 and LIMIT 10: NDV should stay at 5 + let stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Inexact(5), + ..Default::default() + }], + }; + + let result = stats.with_fetch(Some(10), 0, 1).unwrap(); + assert_eq!(result.num_rows, Precision::Exact(10)); + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(5) + ); + } + + #[test] + fn test_try_merge_iter_basic() { + let schema = Arc::new(Schema::new(vec![ + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + ])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }, + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(200))), + min_value: Precision::Exact(ScalarValue::Int32(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }, + ], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), + }, + ColumnStatistics { + null_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int32(Some(180))), + min_value: Precision::Exact(ScalarValue::Int32(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), + }, + ], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Exact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Exact(250)); + + let col1_stats = &summary_stats.column_statistics[0]; + assert_eq!(col1_stats.null_count, Precision::Exact(3)); + assert_eq!( + col1_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col1_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(-10))) + ); + assert_eq!( + col1_stats.sum_value, + Precision::Exact(ScalarValue::Int64(Some(1100))) + ); + + let col2_stats = &summary_stats.column_statistics[1]; + assert_eq!(col2_stats.null_count, Precision::Exact(5)); + assert_eq!( + col2_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col2_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(5))) + ); + assert_eq!( + col2_stats.sum_value, + Precision::Exact(ScalarValue::Int64(Some(2200))) + ); + } + + #[test] + fn test_try_merge_iter_mixed_precision() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + 
total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(250)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-10))) + ); + // sum_value becomes Absent because stats2 has Absent sum + assert_eq!(col_stats.sum_value, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_empty() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let items: Vec<&Statistics> = vec![]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Absent); + assert_eq!(summary_stats.total_byte_size, Precision::Absent); + assert_eq!(summary_stats.column_statistics.len(), 1); + assert_eq!( + summary_stats.column_statistics[0].null_count, + Precision::Absent + ); + } + + #[test] + fn test_try_merge_iter_single_item() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Exact(10), + byte_size: Precision::Exact(40), + }], + }; + + let items = vec![&stats]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats, stats); + } + + #[test] + fn test_try_merge_iter_mismatched_columns() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics::default(); + let stats2 = + Statistics::default().add_column_statistics(ColumnStatistics::new_unknown()); + + let items = vec![&stats1, &stats2]; + let e = Statistics::try_merge_iter(items, &schema).unwrap_err(); + assert_contains!( + e.to_string(), + "Cannot merge statistics with different number of columns: 0 vs 1" + ); + } + + #[test] + fn test_try_merge_iter_three_items() { + // Verify that merging three items works correctly + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int64, + false, + )])); + + let stats1 = Statistics { + num_rows: 
Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int64(Some(100))), + min_value: Precision::Exact(ScalarValue::Int64(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(500))), + distinct_count: Precision::Exact(8), + byte_size: Precision::Exact(80), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(20), + total_byte_size: Precision::Exact(200), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int64(Some(200))), + min_value: Precision::Exact(ScalarValue::Int64(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(1000))), + distinct_count: Precision::Exact(15), + byte_size: Precision::Exact(160), + }], + }; + + let stats3 = Statistics { + num_rows: Precision::Exact(30), + total_byte_size: Precision::Exact(300), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(150))), + min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(2000))), + distinct_count: Precision::Exact(25), + byte_size: Precision::Exact(240), + }], + }; + + let items = vec![&stats1, &stats2, &stats3]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Exact(60)); + assert_eq!(summary_stats.total_byte_size, Precision::Exact(600)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Exact(6)); + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Int64(Some(200))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Int64(Some(1))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::Int64(Some(3500))) + ); + assert_eq!(col_stats.byte_size, Precision::Exact(480)); + // Overlap-based NDV merge (pairwise left-to-right): + // stats1+stats2: [10,100]+[5,200] -> NDV=16, then +stats3: [5,200]+[1,150] -> NDV=29 + assert_eq!(col_stats.distinct_count, Precision::Inexact(29)); + } + + #[test] + fn test_try_merge_iter_float_types() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Float64, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(80), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Float64(Some(99.9))), + min_value: Precision::Exact(ScalarValue::Float64(Some(1.1))), + sum_value: Precision::Exact(ScalarValue::Float64(Some(500.5))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(80), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(80), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Float64(Some(200.0))), + min_value: Precision::Exact(ScalarValue::Float64(Some(0.5))), + sum_value: Precision::Exact(ScalarValue::Float64(Some(1000.0))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(80), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!( + col_stats.max_value, + 
Precision::Exact(ScalarValue::Float64(Some(200.0))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Float64(Some(0.5))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::Float64(Some(1500.5))) + ); + } + + #[test] + fn test_try_merge_iter_string_types() { + let schema = + Arc::new(Schema::new(vec![Field::new("col1", DataType::Utf8, false)])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some("dog".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("ant".to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(100), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some("zebra".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("bat".to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(100), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Utf8(Some("zebra".to_string()))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Utf8(Some("ant".to_string()))) + ); + assert_eq!(col_stats.sum_value, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_all_inexact() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(1), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(40), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(20), + total_byte_size: Precision::Inexact(200), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(-5))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(30)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(300)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-5))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Inexact(ScalarValue::Int64(Some(1500))) + ); + } + + #[test] + fn test_precision_min_in_place() { + // Exact vs Exact: 
keeps the smaller
+        let mut lhs = Precision::Exact(10);
+        precision_min(&mut lhs, &Precision::Exact(20));
+        assert_eq!(lhs, Precision::Exact(10));
+
+        let mut lhs = Precision::Exact(20);
+        precision_min(&mut lhs, &Precision::Exact(10));
+        assert_eq!(lhs, Precision::Exact(10));
+
+        // Equal exact values
+        let mut lhs = Precision::Exact(5);
+        precision_min(&mut lhs, &Precision::Exact(5));
+        assert_eq!(lhs, Precision::Exact(5));
+
+        // Mixed exact/inexact: result is Inexact with smaller value
+        let mut lhs = Precision::Exact(10);
+        precision_min(&mut lhs, &Precision::Inexact(20));
+        assert_eq!(lhs, Precision::Inexact(10));
+
+        let mut lhs = Precision::Inexact(10);
+        precision_min(&mut lhs, &Precision::Exact(5));
+        assert_eq!(lhs, Precision::Inexact(5));
+
+        // Inexact vs Inexact
+        let mut lhs = Precision::Inexact(30);
+        precision_min(&mut lhs, &Precision::Inexact(20));
+        assert_eq!(lhs, Precision::Inexact(20));
+
+        // Absent makes result Absent
+        let mut lhs = Precision::Exact(10);
+        precision_min(&mut lhs, &Precision::Absent);
+        assert_eq!(lhs, Precision::Absent);
+
+        let mut lhs = Precision::<usize>::Absent;
+        precision_min(&mut lhs, &Precision::Exact(10));
+        assert_eq!(lhs, Precision::Absent);
+    }
+
+    #[test]
+    fn test_precision_max_in_place() {
+        // Exact vs Exact: keeps the larger
+        let mut lhs = Precision::Exact(10);
+        precision_max(&mut lhs, &Precision::Exact(20));
+        assert_eq!(lhs, Precision::Exact(20));
+
+        let mut lhs = Precision::Exact(20);
+        precision_max(&mut lhs, &Precision::Exact(10));
+        assert_eq!(lhs, Precision::Exact(20));
+
+        // Equal exact values
+        let mut lhs = Precision::Exact(5);
+        precision_max(&mut lhs, &Precision::Exact(5));
+        assert_eq!(lhs, Precision::Exact(5));
+
+        // Mixed exact/inexact: result is Inexact with larger value
+        let mut lhs = Precision::Exact(10);
+        precision_max(&mut lhs, &Precision::Inexact(20));
+        assert_eq!(lhs, Precision::Inexact(20));
+
+        let mut lhs = Precision::Inexact(10);
+        precision_max(&mut lhs, &Precision::Exact(5));
+        assert_eq!(lhs, Precision::Inexact(10));
+
+        // Inexact vs Inexact
+        let mut lhs = Precision::Inexact(20);
+        precision_max(&mut lhs, &Precision::Inexact(30));
+        assert_eq!(lhs, Precision::Inexact(30));
+
+        // Absent makes result Absent
+        let mut lhs = Precision::Exact(10);
+        precision_max(&mut lhs, &Precision::Absent);
+        assert_eq!(lhs, Precision::Absent);
+
+        let mut lhs = Precision::<usize>::Absent;
+        precision_max(&mut lhs, &Precision::Exact(10));
+        assert_eq!(lhs, Precision::Absent);
+    }
+
+    #[test]
+    fn test_cast_sum_value_to_sum_type_in_place_widens_int32() {
+        let mut value = Precision::Exact(ScalarValue::Int32(Some(42)));
+        cast_sum_value_to_sum_type_in_place(&mut value);
+        assert_eq!(value, Precision::Exact(ScalarValue::Int64(Some(42))));
+    }
+
+    #[test]
+    fn test_cast_sum_value_to_sum_type_in_place_preserves_int64() {
+        // Int64 is already the sum type for Int64, no widening needed
+        let mut value = Precision::Exact(ScalarValue::Int64(Some(100)));
+        cast_sum_value_to_sum_type_in_place(&mut value);
+        assert_eq!(value, Precision::Exact(ScalarValue::Int64(Some(100))));
+    }
+
+    #[test]
+    fn test_cast_sum_value_to_sum_type_in_place_inexact() {
+        let mut value = Precision::Inexact(ScalarValue::Int32(Some(42)));
+        cast_sum_value_to_sum_type_in_place(&mut value);
+        assert_eq!(value, Precision::Inexact(ScalarValue::Int64(Some(42))));
+    }
+
+    #[test]
+    fn test_cast_sum_value_to_sum_type_in_place_absent() {
+        let mut value = Precision::<ScalarValue>::Absent;
+        cast_sum_value_to_sum_type_in_place(&mut value);
+        assert_eq!(value, Precision::Absent);
+
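As a cross-check on the two in-place helpers exercised above, the value-level merge rules they pin down can be written as a standalone sketch (simplified non-generic `Precision`; the name mirrors, but is not, the crate's helper):

```rust
// Merge lattice for min: Exact only survives when both sides are Exact,
// any Inexact operand makes the result Inexact, and Absent absorbs
// (without one operand there is no usable bound). `precision_max` is the
// mirror image with `max` in place of `min`.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Precision {
    Exact(u64),
    Inexact(u64),
    Absent,
}

fn precision_min(lhs: Precision, rhs: Precision) -> Precision {
    use Precision::*;
    match (lhs, rhs) {
        (Exact(a), Exact(b)) => Exact(a.min(b)),
        (Exact(a) | Inexact(a), Exact(b) | Inexact(b)) => Inexact(a.min(b)),
        _ => Absent,
    }
}

fn main() {
    assert_eq!(
        precision_min(Precision::Exact(10), Precision::Inexact(20)),
        Precision::Inexact(10)
    );
    assert_eq!(
        precision_min(Precision::Exact(10), Precision::Absent),
        Precision::Absent
    );
}
```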
} + + #[test] + fn test_precision_add_for_sum_in_place_same_type() { + // Int64 + Int64: no widening needed, straight add + let mut lhs = Precision::Exact(ScalarValue::Int64(Some(10))); + let rhs = Precision::Exact(ScalarValue::Int64(Some(20))); + precision_add_for_sum_in_place(&mut lhs, &rhs); + assert_eq!(lhs, Precision::Exact(ScalarValue::Int64(Some(30)))); + } + + #[test] + fn test_precision_add_for_sum_in_place_widens_rhs() { + // lhs is already Int64 (widened), rhs is Int32 -> gets cast to Int64 + let mut lhs = Precision::Exact(ScalarValue::Int64(Some(10))); + let rhs = Precision::Exact(ScalarValue::Int32(Some(5))); + precision_add_for_sum_in_place(&mut lhs, &rhs); + assert_eq!(lhs, Precision::Exact(ScalarValue::Int64(Some(15)))); + } + + #[test] + fn test_precision_add_for_sum_in_place_inexact() { + let mut lhs = Precision::Inexact(ScalarValue::Int64(Some(10))); + let rhs = Precision::Inexact(ScalarValue::Int32(Some(5))); + precision_add_for_sum_in_place(&mut lhs, &rhs); + assert_eq!(lhs, Precision::Inexact(ScalarValue::Int64(Some(15)))); + } + + #[test] + fn test_precision_add_for_sum_in_place_absent_rhs() { + let mut lhs = Precision::Exact(ScalarValue::Int64(Some(10))); + precision_add_for_sum_in_place(&mut lhs, &Precision::Absent); + assert_eq!(lhs, Precision::Absent); } } diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index c51dea1c4de04..f060704944233 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -735,32 +735,34 @@ mod tests { let non_existing = cwd.join("non-existing-dir").display().to_string(); let non_existing_str = non_existing.as_str(); - env::set_var(udf_env, non_existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_err()); - - env::set_var(udf_env, ""); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::set_var(udf_env, " "); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::set_var(udf_env, existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::remove_var(udf_env); - let res = get_data_dir(udf_env, non_existing_str); - assert!(res.is_err()); - - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); + unsafe { + env::set_var(udf_env, non_existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_err()); + + env::set_var(udf_env, ""); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::set_var(udf_env, " "); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::set_var(udf_env, existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::remove_var(udf_env); + let res = get_data_dir(udf_env, non_existing_str); + assert!(res.is_err()); + + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + } } #[test] diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 9b36266eec2e9..39300b9564621 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -796,7 +796,9 @@ pub trait TreeNodeContainer<'a, T: 'a>: Sized { ) -> Result>; 
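The `Box` and `Arc` implementations that follow rewrite the container in place. A minimal standalone illustration of the `mem::take` pattern they rely on (names here are illustrative, not the DataFusion API):

```rust
// `mem::take` moves the inner value out while leaving `T::default()` behind,
// so the box always contains a valid value even if `f` panics or errors,
// and the original heap allocation is reused for the rewritten value.
fn map_boxed<T: Default>(
    mut boxed: Box<T>,
    f: impl FnOnce(T) -> Result<T, String>,
) -> Result<Box<T>, String> {
    let inner = std::mem::take(&mut *boxed);
    *boxed = f(inner)?;
    Ok(boxed)
}

fn main() {
    let b = Box::new(41i32);
    let before: *const i32 = &*b;
    let b = map_boxed(b, |v| Ok(v + 1)).unwrap();
    assert_eq!(*b, 42);
    // Same allocation before and after the rewrite.
    assert_eq!(before, &*b as *const i32);
}
```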
} -impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box { +impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Default> TreeNodeContainer<'a, T> + for Box +{ fn apply_elements Result>( &'a self, f: F, @@ -805,14 +807,24 @@ impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box } fn map_elements Result>>( - self, + mut self, f: F, ) -> Result> { - (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c))) + // Rewrite in place so the existing heap allocation can be reused. + // `mem::take` hands the inner `C` to `f` while leaving + // `C::default()` in the slot, so an unwinding drop finds a valid + // `C` even if `f` panics or the `?` short-circuits. + let inner = std::mem::take(&mut *self); + Ok(inner.map_elements(f)?.update_data(|c| { + *self = c; + self + })) } } -impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc { +impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone + Default> TreeNodeContainer<'a, T> + for Arc +{ fn apply_elements Result>( &'a self, f: F, @@ -821,12 +833,18 @@ impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> fo } fn map_elements Result>>( - self, + mut self, f: F, ) -> Result> { - Arc::unwrap_or_clone(self) - .map_elements(f)? - .map_data(|c| Ok(Arc::new(c))) + // Rewrite in place using the same `mem::take` strategy as + // `Box::map_elements`. `Arc::make_mut` gives us exclusive + // access (cloning `C` first if we were sharing), after which + // `get_mut` is infallible. + let inner = std::mem::take(Arc::make_mut(&mut self)); + Ok(inner.map_elements(f)?.update_data(|c| { + *Arc::get_mut(&mut self).unwrap() = c; + self + })) } } @@ -956,12 +974,12 @@ impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - > TreeNodeContainer<'a, T> for (C0, C1, C2) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, +> TreeNodeContainer<'a, T> for (C0, C1, C2) { fn apply_elements Result>( &'a self, @@ -992,13 +1010,13 @@ impl< } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - C3: TreeNodeContainer<'a, T>, - > TreeNodeContainer<'a, T> for (C0, C1, C2, C3) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, +> TreeNodeContainer<'a, T> for (C0, C1, C2, C3) { fn apply_elements Result>( &'a self, @@ -1090,12 +1108,12 @@ impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, +> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) { fn apply_ref_elements Result>( &self, @@ -1109,13 +1127,13 @@ impl< } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - C3: TreeNodeContainer<'a, T>, - > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, +> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3) { fn 
apply_ref_elements Result>( &self, @@ -1335,14 +1353,15 @@ impl TreeNode for T { pub(crate) mod tests { use std::collections::HashMap; use std::fmt::Display; + use std::sync::Arc; + use crate::Result; use crate::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; - use crate::Result; - #[derive(Debug, Eq, Hash, PartialEq, Clone)] + #[derive(Debug, Default, Eq, Hash, PartialEq, Clone)] pub struct TestTreeNode { pub(crate) children: Vec>, pub(crate) data: T, @@ -2431,4 +2450,46 @@ pub(crate) mod tests { item.visit(&mut visitor).unwrap(); } + + #[test] + fn box_map_elements_reuses_allocation() { + let boxed = Box::new(TestTreeNode::new_leaf(42i32)); + let before: *const TestTreeNode = &*boxed; + let out = boxed.map_elements(|n| Ok(Transformed::no(n))).unwrap(); + let after: *const TestTreeNode = &*out.data; + assert_eq!(after, before); + } + + #[test] + fn arc_map_elements_reuses_allocation_when_unique() { + let arc = Arc::new(TestTreeNode::new_leaf(42i32)); + let before = Arc::as_ptr(&arc); + let out = arc.map_elements(|n| Ok(Transformed::no(n))).unwrap(); + assert_eq!(Arc::as_ptr(&out.data), before); + } + + #[test] + fn arc_map_elements_clones_when_shared() { + // When the input `Arc` is shared, `make_mut` clones into a fresh + // allocation, so the reuse optimization does not apply. + let arc = Arc::new(TestTreeNode::new_leaf(42i32)); + let _keepalive = Arc::clone(&arc); + let before = Arc::as_ptr(&arc); + let out = arc.map_elements(|n| Ok(Transformed::no(n))).unwrap(); + assert_ne!(Arc::as_ptr(&out.data), before); + } + + #[test] + fn box_map_elements_panic() { + use std::panic::{AssertUnwindSafe, catch_unwind}; + let boxed = Box::new(TestTreeNode::new_leaf(42i32)); + let result = catch_unwind(AssertUnwindSafe(|| { + boxed + .map_elements(|_: TestTreeNode| -> Result<_> { + panic!("simulated panic during rewrite") + }) + .ok() + })); + assert!(result.is_err()); + } } diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index 314529b99a342..dfd2cc4cf2d8b 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::IntervalUnit::*; +use arrow::datatypes::TimeUnit::*; use crate::types::{LogicalTypeRef, NativeType}; use std::sync::{Arc, LazyLock}; @@ -82,3 +83,17 @@ singleton_variant!( Interval, MonthDayNano ); + +singleton_variant!( + LOGICAL_INTERVAL_YEAR_MONTH, + logical_interval_year_month, + Interval, + YearMonth +); + +singleton_variant!( + LOGICAL_DURATION_MICROSECOND, + logical_duration_microsecond, + Duration, + Microsecond +); diff --git a/datafusion/common/src/types/canonical_extensions/bool8.rs b/datafusion/common/src/types/canonical_extensions/bool8.rs new file mode 100644 index 0000000000000..e0f7a5914a6ed --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/bool8.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::error::_internal_err; +use crate::types::extension::DFExtensionType; +use arrow::array::{Array, Int8Array}; +use arrow::datatypes::DataType; +use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult}; +use arrow_schema::extension::{Bool8, ExtensionType}; +use std::fmt::Write; + +/// Defines the extension type logic for the canonical `arrow.bool8` extension type. This extension +/// type allows storing a Boolean value in a single byte, instead of a single bit. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`Bool8`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFBool8(Bool8); + +impl DFBool8 { + /// Creates a new [`DFBool8`], validating that the storage type is compatible with the + /// extension type. + /// + /// Even though [`DFBool8`] only supports a single storage type ([`DataType::Int8`]), passing-in + /// the storage type allows conveniently validating whether this extension type is compatible + /// with a given [`DataType`]. + pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + // Validates the storage type + Ok(Self(::try_new( + data_type, metadata, + )?)) + } +} + +impl DFExtensionType for DFBool8 { + fn storage_type(&self) -> DataType { + DataType::Int8 + } + + fn serialize_metadata(&self) -> Option { + self.0.serialize_metadata() + } + + fn create_array_formatter<'fmt>( + &self, + array: &'fmt dyn Array, + options: &FormatOptions<'fmt>, + ) -> Result>> { + if array.data_type() != &DataType::Int8 { + return _internal_err!("Wrong array type for Bool8"); + } + + let display_index = Bool8ValueDisplayIndex { + array: array.as_any().downcast_ref().unwrap(), + null_str: options.null(), + }; + Ok(Some(ArrayFormatter::new( + Box::new(display_index), + options.safe(), + ))) + } +} + +/// Pretty printer for binary bool8 values. 
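The display rule implemented by the `DisplayIndex` below is C-style truthiness over the Int8 storage; note the test further down expects `-20` to print as `true`. A tiny sketch:

```rust
// arrow.bool8 stores booleans as Int8: zero is false, any non-zero byte
// (including negative values) is true.
fn bool8_to_str(byte: i8) -> &'static str {
    if byte != 0 { "true" } else { "false" }
}

fn main() {
    assert_eq!(bool8_to_str(0), "false");
    assert_eq!(bool8_to_str(1), "true");
    assert_eq!(bool8_to_str(-20), "true");
}
```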
+#[derive(Debug, Clone, Copy)] +struct Bool8ValueDisplayIndex<'a> { + array: &'a Int8Array, + null_str: &'a str, +} + +impl DisplayIndex for Bool8ValueDisplayIndex<'_> { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.array.is_null(idx) { + write!(f, "{}", self.null_str)?; + return Ok(()); + } + + let bytes = self.array.value(idx); + write!(f, "{}", bytes != 0)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn test_pretty_bool8() { + let values = Int8Array::from_iter([Some(0), Some(1), Some(-20), None]); + + let extension_type = DFBool8(Bool8 {}); + let formatter = extension_type + .create_array_formatter(&values, &FormatOptions::default().with_null("NULL")) + .unwrap() + .unwrap(); + + assert_eq!(formatter.value(0).to_string(), "false"); + assert_eq!(formatter.value(1).to_string(), "true"); + assert_eq!(formatter.value(2).to_string(), "true"); + assert_eq!(formatter.value(3).to_string(), "NULL"); + } +} diff --git a/datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs b/datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs new file mode 100644 index 0000000000000..9148d9a1b39f2 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::{ExtensionType, FixedShapeTensor}; + +/// Defines the extension type logic for the canonical `arrow.fixed_shape_tensor` extension type. +/// This extension type can be used to store a [tensor](https://en.wikipedia.org/wiki/Tensor) of +/// a fixed shape. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`FixedShapeTensor`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFFixedShapeTensor { + inner: FixedShapeTensor, + /// The storage type of the tensor. + /// + /// While we could reconstruct the storage type from the inner [`FixedShapeTensor`], we may + /// choose a different name for the field within the [`DataType::FixedSizeList`] which can + /// cause problems down the line (e.g., checking for equality). + storage_type: DataType, +} + +impl DFFixedShapeTensor { + /// Creates a new [`DFFixedShapeTensor`], validating that the storage type is compatible with + /// the extension type. 
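The cached `storage_type` above exists because `DataType` equality is sensitive to the inner field name chosen for the list. A small demonstration with arrow-rs:

```rust
use arrow::datatypes::{DataType, Field};
use std::sync::Arc;

fn main() {
    // Identical shape and element type, but different inner field names:
    // the two FixedSizeList data types do not compare equal, which is why
    // the original storage type is cached instead of being reconstructed.
    let a = DataType::FixedSizeList(
        Arc::new(Field::new("item", DataType::Int32, false)),
        3,
    );
    let b = DataType::FixedSizeList(
        Arc::new(Field::new("element", DataType::Int32, false)),
        3,
    );
    assert_ne!(a, b);
}
```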
+ pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + Ok(Self { + inner: ::try_new(data_type, metadata)?, + storage_type: data_type.clone(), + }) + } +} + +impl DFExtensionType for DFFixedShapeTensor { + fn storage_type(&self) -> DataType { + self.storage_type.clone() + } + + fn serialize_metadata(&self) -> Option { + self.inner.serialize_metadata() + } +} diff --git a/datafusion/common/src/types/canonical_extensions/json.rs b/datafusion/common/src/types/canonical_extensions/json.rs new file mode 100644 index 0000000000000..8be9993a26061 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/json.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::{ExtensionType, Json}; + +/// Defines the extension type logic for the canonical `arrow.json` extension type. This extension +/// type defines that a particular string field stores JSON values. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`Json`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFJson { + inner: Json, + storage_type: DataType, +} + +impl DFJson { + /// Creates a new [`DFJson`], validating that the storage type is compatible with the + /// extension type. + pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + Ok(Self { + inner: ::try_new(data_type, metadata)?, + storage_type: data_type.clone(), + }) + } +} + +impl DFExtensionType for DFJson { + fn storage_type(&self) -> DataType { + self.storage_type.clone() + } + + fn serialize_metadata(&self) -> Option { + self.inner.serialize_metadata() + } +} diff --git a/datafusion/common/src/types/canonical_extensions/mod.rs b/datafusion/common/src/types/canonical_extensions/mod.rs new file mode 100644 index 0000000000000..2d74d0669d213 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/mod.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bool8; +mod fixed_shape_tensor; +mod json; +mod opaque; +mod timestamp_with_offset; +mod uuid; +mod variable_shape_tensor; + +pub use bool8::DFBool8; +pub use fixed_shape_tensor::DFFixedShapeTensor; +pub use json::DFJson; +pub use opaque::DFOpaque; +pub use timestamp_with_offset::DFTimestampWithOffset; +pub use uuid::DFUuid; +pub use variable_shape_tensor::DFVariableShapeTensor; diff --git a/datafusion/common/src/types/canonical_extensions/opaque.rs b/datafusion/common/src/types/canonical_extensions/opaque.rs new file mode 100644 index 0000000000000..d14f07737b6a7 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/opaque.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::{ExtensionType, Opaque}; + +/// Defines the extension type logic for the canonical `arrow.opaque` extension type. This extension +/// type represents types that DataFusion cannot interpret. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`Opaque`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFOpaque { + inner: Opaque, + storage_type: DataType, +} + +impl DFOpaque { + /// Creates a new [`DFOpaque`], validating that the storage type is compatible with the + /// extension type. + pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + Ok(Self { + inner: ::try_new(data_type, metadata)?, + storage_type: data_type.clone(), + }) + } +} + +impl DFExtensionType for DFOpaque { + fn storage_type(&self) -> DataType { + self.storage_type.clone() + } + + fn serialize_metadata(&self) -> Option { + self.inner.serialize_metadata() + } +} diff --git a/datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs b/datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs new file mode 100644 index 0000000000000..58a5fff9d0c28 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::ScalarValue; +use crate::error::_internal_err; +use crate::types::extension::DFExtensionType; +use arrow::array::{Array, AsArray, Int16Array}; +use arrow::buffer::NullBuffer; +use arrow::compute::cast; +use arrow::datatypes::{ + DataType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; +use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult}; +use arrow_schema::ArrowError; +use arrow_schema::extension::{ExtensionType, TimestampWithOffset}; +use std::fmt::Write; + +/// Defines the extension type logic for the canonical `arrow.timestamp_with_offset` extension type. +/// This extension type allows associating a different offset for each timestamp in a column. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`TimestampWithOffset`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFTimestampWithOffset { + inner: TimestampWithOffset, + storage_type: DataType, +} + +impl DFTimestampWithOffset { + /// Creates a new [`DFTimestampWithOffset`], validating that the storage type is compatible with + /// the extension type. + pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + Ok(Self { + inner: ::try_new(data_type, metadata)?, + storage_type: data_type.clone(), + }) + } +} + +impl DFExtensionType for DFTimestampWithOffset { + fn storage_type(&self) -> DataType { + self.storage_type.clone() + } + + fn serialize_metadata(&self) -> Option { + self.inner.serialize_metadata() + } + + fn create_array_formatter<'fmt>( + &self, + array: &'fmt dyn Array, + options: &FormatOptions<'fmt>, + ) -> Result>> { + if array.data_type() != &self.storage_type { + return _internal_err!( + "Unexpected data type for TimestampWithOffset: {}", + array.data_type() + ); + } + + let struct_array = array.as_struct(); + let timestamp_array = struct_array + .column_by_name("timestamp") + .expect("Type checked above") + .as_ref(); + let raw_offset_array = struct_array + .column_by_name("offset_minutes") + .expect("Type checked above"); + + // Get a regular [`Int16Array`], if the offset array is a dictionary or run-length encoded. + let offset_array = cast(&raw_offset_array, &DataType::Int16)? + .as_primitive() + .clone(); + + let display_index = TimestampWithOffsetDisplayIndex { + null_buffer: struct_array.nulls(), + timestamp_array, + offset_array, + options: options.clone(), + }; + + Ok(Some(ArrayFormatter::new( + Box::new(display_index), + options.safe(), + ))) + } +} + +struct TimestampWithOffsetDisplayIndex<'a> { + /// The inner arrays are always non-null. Use the null buffer of the struct array to check + /// whether an element is null. 
+ null_buffer: Option<&'a NullBuffer>, + timestamp_array: &'a dyn Array, + offset_array: Int16Array, + options: FormatOptions<'a>, +} + +impl DisplayIndex for TimestampWithOffsetDisplayIndex<'_> { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.null_buffer.map(|nb| nb.is_null(idx)).unwrap_or(false) { + write!(f, "{}", self.options.null())?; + return Ok(()); + } + + let offset_minutes = self.offset_array.value(idx); + let offset = format_offset(offset_minutes); + + // The timestamp array must be UTC, so we can ignore the timezone. + let scalar = match self.timestamp_array.data_type() { + DataType::Timestamp(TimeUnit::Second, _) => { + let ts = self + .timestamp_array + .as_primitive::() + .value(idx); + ScalarValue::TimestampSecond(Some(ts), Some(offset.into())) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let ts = self + .timestamp_array + .as_primitive::() + .value(idx); + ScalarValue::TimestampMillisecond(Some(ts), Some(offset.into())) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let ts = self + .timestamp_array + .as_primitive::() + .value(idx); + ScalarValue::TimestampMicrosecond(Some(ts), Some(offset.into())) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let ts = self + .timestamp_array + .as_primitive::() + .value(idx); + ScalarValue::TimestampNanosecond(Some(ts), Some(offset.into())) + } + _ => unreachable!("TimestampWithOffset storage must be a Timestamp array"), + }; + + let array = scalar.to_array().map_err(|_| { + ArrowError::ComputeError("Failed to convert scalar to array".to_owned()) + })?; + let formatter = ArrayFormatter::try_new(&array, &self.options)?; + formatter.value(0).write(f)?; + + Ok(()) + } +} + +/// Formats the offset in the format `+/-HH:MM`, which can be used as an offset in the regular +/// timestamp types. 
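End-to-end, the formatter above renders a UTC-stored timestamp at its per-row offset. A sketch of the expected output for one row from the tests below, using chrono (already a test dependency here):

```rust
use chrono::{FixedOffset, TimeZone, Utc};

fn main() {
    // Stored value: 2024-04-01T00:00:00 UTC; per-row offset: -105 minutes.
    let utc = Utc.with_ymd_and_hms(2024, 4, 1, 0, 0, 0).unwrap();
    let offset = FixedOffset::east_opt(-105 * 60).unwrap(); // 1h45m west
    assert_eq!(
        utc.with_timezone(&offset).to_rfc3339(),
        "2024-03-31T22:15:00-01:45"
    );
}
```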
+fn format_offset(minutes: i16) -> String { + let sign = if minutes >= 0 { '+' } else { '-' }; + let minutes = minutes.abs(); + let hours = minutes / 60; + let minutes = minutes % 60; + format!("{sign}{hours:02}:{minutes:02}") +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Array, DictionaryArray, Int16Array, Int32Array, RunArray, StructArray, + TimestampSecondArray, + }; + use arrow::buffer::NullBuffer; + use arrow::datatypes::{Field, Fields, Int16Type, Int32Type}; + use chrono::{TimeZone, Utc}; + use std::sync::Arc; + + #[test] + fn test_pretty_print_timestamp_with_offset() -> Result<(), ArrowError> { + let ts = Utc + .with_ymd_and_hms(2024, 4, 1, 0, 0, 0) + .unwrap() + .timestamp(); + + let offset_array = Arc::new(Int16Array::from(vec![60, -105, 0])); + + run_formatting_test( + vec![ts, ts, ts], + offset_array, + Some(NullBuffer::from(vec![true, true, false])), + FormatOptions::default().with_null("NULL"), + &[ + "2024-04-01T01:00:00+01:00", + "2024-03-31T22:15:00-01:45", + "NULL", + ], + ) + } + + #[test] + fn test_pretty_print_dictionary_offset() -> Result<(), ArrowError> { + let ts = Utc + .with_ymd_and_hms(2024, 4, 1, 12, 0, 0) + .unwrap() + .timestamp(); + + let offset_array = Arc::new(DictionaryArray::::new( + Int16Array::from(vec![0, 1, 0]), + Arc::new(Int16Array::from(vec![60, -60])), + )); + + run_formatting_test( + vec![ts, ts, ts], + offset_array, + None, + FormatOptions::default(), + &[ + "2024-04-01T13:00:00+01:00", + "2024-04-01T11:00:00-01:00", + "2024-04-01T13:00:00+01:00", + ], + ) + } + + #[test] + fn test_pretty_print_rle_offset() -> Result<(), ArrowError> { + let ts = Utc + .with_ymd_and_hms(2024, 4, 1, 12, 0, 0) + .unwrap() + .timestamp(); + + let run_ends = Int32Array::from(vec![2]); + let values = Int16Array::from(vec![120]); + let offset_array = Arc::new(RunArray::::try_new(&run_ends, &values)?); + + run_formatting_test( + vec![ts, ts], + offset_array, + None, + FormatOptions::default(), + &["2024-04-01T14:00:00+02:00", "2024-04-01T14:00:00+02:00"], + ) + } + + /// Create valid fields with flexible offset types + fn create_fields_custom_offset(time_unit: TimeUnit, offset_type: DataType) -> Fields { + let ts_field = Field::new( + "timestamp", + DataType::Timestamp(time_unit, Some("UTC".into())), + false, + ); + let offset_field = Field::new("offset_minutes", offset_type, false); + Fields::from(vec![ts_field, offset_field]) + } + + /// Helper to construct the arrays, run the formatter, and assert the expected strings. + fn run_formatting_test( + timestamps: Vec, + offset_array: Arc, + null_buffer: Option, + options: FormatOptions, + expected: &[&str], + ) -> Result<(), ArrowError> { + let fields = create_fields_custom_offset( + TimeUnit::Second, + offset_array.data_type().clone(), + ); + + let struct_array = StructArray::try_new( + fields, + vec![ + Arc::new(TimestampSecondArray::from(timestamps).with_timezone("UTC")), + offset_array, + ], + null_buffer, + )?; + + let formatter = DFTimestampWithOffset::try_new(struct_array.data_type(), ())? + .create_array_formatter(&struct_array, &options)? 
+ .unwrap(); + + for (i, expected_str) in expected.iter().enumerate() { + assert_eq!(formatter.value(i).to_string(), *expected_str); + } + + Ok(()) + } +} diff --git a/datafusion/common/src/types/canonical_extensions/uuid.rs b/datafusion/common/src/types/canonical_extensions/uuid.rs new file mode 100644 index 0000000000000..8cbcf3f58a80e --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/uuid.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::error::_internal_err; +use crate::types::extension::DFExtensionType; +use arrow::array::{Array, FixedSizeBinaryArray}; +use arrow::datatypes::DataType; +use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult}; +use arrow_schema::extension::{ExtensionType, Uuid}; +use std::fmt::Write; +use uuid::Bytes; + +/// Defines the extension type logic for the canonical `arrow.uuid` extension type. This extension +/// type defines that a field should be interpreted as a +/// [UUID](https://de.wikipedia.org/wiki/Universally_Unique_Identifier). +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`Uuid`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFUuid(Uuid); + +impl DFUuid { + /// Creates a new [`DFUuid`], validating that the storage type is compatible with the + /// extension type. + pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + Ok(Self(::try_new(data_type, metadata)?)) + } +} + +impl DFExtensionType for DFUuid { + fn storage_type(&self) -> DataType { + DataType::FixedSizeBinary(16) + } + + fn serialize_metadata(&self) -> Option { + self.0.serialize_metadata() + } + + fn create_array_formatter<'fmt>( + &self, + array: &'fmt dyn Array, + options: &FormatOptions<'fmt>, + ) -> Result>> { + if array.data_type() != &DataType::FixedSizeBinary(16) { + return _internal_err!("Wrong array type for Uuid"); + } + + let display_index = UuidValueDisplayIndex { + array: array.as_any().downcast_ref().unwrap(), + null_str: options.null(), + }; + Ok(Some(ArrayFormatter::new( + Box::new(display_index), + options.safe(), + ))) + } +} + +/// Pretty printer for binary UUID values. 
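The display rule below delegates to the `uuid` crate (already imported in this module): 16 raw bytes become the canonical hyphenated form. A minimal sketch:

```rust
fn main() {
    let bytes = [0u8; 16];
    let uuid = uuid::Uuid::from_bytes(bytes);
    assert_eq!(uuid.to_string(), "00000000-0000-0000-0000-000000000000");
}
```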
+#[derive(Debug, Clone, Copy)] +struct UuidValueDisplayIndex<'a> { + array: &'a FixedSizeBinaryArray, + null_str: &'a str, +} + +impl DisplayIndex for UuidValueDisplayIndex<'_> { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.array.is_null(idx) { + write!(f, "{}", self.null_str)?; + return Ok(()); + } + + let bytes = Bytes::try_from(self.array.value(idx)) + .expect("FixedSizeBinaryArray length checked in create_array_formatter"); + let uuid = uuid::Uuid::from_bytes(bytes); + write!(f, "{uuid}")?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ScalarValue; + use arrow_schema::ArrowError; + + #[test] + pub fn test_pretty_print_uuid() -> Result<(), ArrowError> { + let my_uuid = uuid::Uuid::nil(); + let uuid = ScalarValue::FixedSizeBinary(16, Some(my_uuid.as_bytes().to_vec())) + .to_array_of_size(1)?; + + let formatter = DFUuid::try_new(uuid.data_type(), ())? + .create_array_formatter(uuid.as_ref(), &FormatOptions::default())? + .unwrap(); + + assert_eq!( + formatter.value(0).to_string(), + "00000000-0000-0000-0000-000000000000" + ); + + Ok(()) + } +} diff --git a/datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs b/datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs new file mode 100644 index 0000000000000..00f59c70160e5 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::{ExtensionType, VariableShapeTensor}; + +/// Defines the extension type logic for the canonical `arrow.variable_shape_tensor` extension type. +/// This extension type can be used to store a [tensor](https://en.wikipedia.org/wiki/Tensor) with +/// variable shape that can change for each element. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also +/// [`VariableShapeTensor`] for the implementation of arrow-rs, which this type uses internally. +/// +/// +#[derive(Debug, Clone)] +pub struct DFVariableShapeTensor { + inner: VariableShapeTensor, + /// While we could reconstruct the storage type from the inner [`VariableShapeTensor`], we may + /// choose a different name for the field within the [`DataType::List`] which can cause problems + /// down the line (e.g., checking for equality). + storage_type: DataType, +} + +impl DFVariableShapeTensor { + /// Creates a new [`DFVariableShapeTensor`], validating that the storage type is compatible with + /// the extension type. 
+ pub fn try_new( + data_type: &DataType, + metadata: ::Metadata, + ) -> Result { + Ok(Self { + inner: ::try_new(data_type, metadata)?, + storage_type: data_type.clone(), + }) + } +} + +impl DFExtensionType for DFVariableShapeTensor { + fn storage_type(&self) -> DataType { + self.storage_type.clone() + } + + fn serialize_metadata(&self) -> Option { + self.inner.serialize_metadata() + } +} diff --git a/datafusion/common/src/types/extension.rs b/datafusion/common/src/types/extension.rs new file mode 100644 index 0000000000000..3bcb533dbf9e6 --- /dev/null +++ b/datafusion/common/src/types/extension.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::Result; +use arrow::array::Array; +use arrow::util::display::{ArrayFormatter, FormatOptions}; +use arrow_schema::DataType; +use std::fmt::Debug; +use std::sync::Arc; + +/// A cheaply cloneable pointer to a [`DFExtensionType`]. +pub type DFExtensionTypeRef = Arc; + +/// Represents an implementation of a DataFusion extension type, including the storage [`DataType`]. +/// While, in general, an extension type can support several different storage types, a specific +/// instance of it is always locked into just one exact storage type and metadata pairing. +/// +/// This trait allows users to customize the behavior of DataFusion for certain types. Having this +/// ability is necessary because extension types affect how columns should be treated by the query +/// engine. This effect includes which operations are possible on a column and what are the expected +/// results from these operations. The extension type mechanism allows users to define how these +/// operations apply to a particular extension type. +/// +/// For example, adding two values of [`Int64`](arrow::datatypes::DataType::Int64) is a sensible +/// thing to do. However, if the same column is annotated with an extension type like `custom.id`, +/// the correct interpretation of a column changes. Adding together two `custom.id` values, even +/// though they are stored as integers, may no longer make sense. +/// +/// Note that DataFusion's extension type support is still young and therefore might not cover all +/// relevant use cases. Currently, the following operations can be customized: +/// - Pretty-printing values in record batches +/// +/// # Relation to Arrow's [`ExtensionType`](arrow_schema::extension::ExtensionType) +/// +/// The purpose of Arrow's [`ExtensionType`](arrow_schema::extension::ExtensionType) trait, for the +/// time being, is to allow reading and writing extension type metadata in a type-safe manner. The +/// trait does not provide any customization options. 
+/// Therefore, downstream users (such as
+/// DataFusion) have the flexibility to implement the extension type mechanism according to their
+/// needs. [`DFExtensionType`] is DataFusion's implementation of this extension type mechanism.
+///
+/// Furthermore, the current trait in arrow-rs is not dyn-compatible, whereas dyn-compatibility is
+/// needed to implement extension type registries. In the future, the two implementations may
+/// increasingly converge.
+///
+/// Another difference is that [`DFExtensionType`] represents a fully resolved extension type that
+/// has a fixed storage type (i.e., [`DataType`]). This is different from arrow-rs, which only
+/// stores the extension type's metadata. For example, an instance of DataFusion's JSON extension
+/// type fixes one of the three possible storage types: [`DataType::Utf8`],
+/// [`DataType::LargeUtf8`], or [`DataType::Utf8View`]. This fixed storage type is returned by
+/// [`DFExtensionType::storage_type`]. This is not possible in arrow-rs' extension type instances.
+/// This is why we have separate types in DataFusion that usually delegate the metadata
+/// processing to the underlying arrow-rs extension type instance
+/// (e.g., [`DFJson`](crate::types::DFJson) instead of [`Json`](arrow_schema::extension::Json)).
+///
+/// # Examples
+///
+/// Examples for using the extension type machinery can be found in the DataFusion examples
+/// directory.
+pub trait DFExtensionType: Debug + Send + Sync {
+    /// Returns the underlying storage type.
+    fn storage_type(&self) -> DataType;
+
+    /// Returns the serialized metadata.
+    fn serialize_metadata(&self) -> Option<String>;
+
+    /// Returns an [`ArrayFormatter`] that can format values of this type.
+    ///
+    /// If `Ok(None)` is returned, the default implementation will be used.
+    /// If an error is returned, there was an error creating the formatter.
+    fn create_array_formatter<'fmt>(
+        &self,
+        _array: &'fmt dyn Array,
+        _options: &FormatOptions<'fmt>,
+    ) -> Result<Option<ArrayFormatter<'fmt>>> {
+        Ok(None)
+    }
+}
diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs
index 674b1a41204d1..f11f1b47b16d3 100644
--- a/datafusion/common/src/types/logical.rs
+++ b/datafusion/common/src/types/logical.rs
@@ -100,7 +100,10 @@ impl fmt::Debug for dyn LogicalType {
 impl std::fmt::Display for dyn LogicalType {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{self:?}")
+        match self.signature() {
+            TypeSignature::Native(_) => write!(f, "{}", self.native()),
+            TypeSignature::Extension { name, .. } => write!(f, "{name}"),
+        }
     }
 }
 
@@ -132,3 +135,118 @@ impl Hash for dyn LogicalType {
         self.signature().hash(state);
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::types::{
+        LogicalField, LogicalFields, logical_boolean, logical_date, logical_float32,
+        logical_float64, logical_int32, logical_int64, logical_null, logical_string,
+    };
+    use arrow::datatypes::{Field, Fields};
+    use insta::assert_snapshot;
+
+    #[test]
+    fn test_logical_type_display_simple() {
+        assert_snapshot!(logical_null(), @"Null");
+        assert_snapshot!(logical_boolean(), @"Boolean");
+        assert_snapshot!(logical_int32(), @"Int32");
+        assert_snapshot!(logical_int64(), @"Int64");
+        assert_snapshot!(logical_float32(), @"Float32");
+        assert_snapshot!(logical_float64(), @"Float64");
+        assert_snapshot!(logical_string(), @"String");
+        assert_snapshot!(logical_date(), @"Date");
+    }
+
+    #[test]
+    fn test_logical_type_display_list() {
+        let list_type: Arc<dyn LogicalType> = Arc::new(NativeType::List(Arc::new(
+            LogicalField::from(&Field::new("item", DataType::Int32, true)),
+        )));
+        assert_snapshot!(list_type, @"List(Int32)");
+    }
+
+    #[test]
+    fn test_logical_type_display_struct() {
+        let struct_type: Arc<dyn LogicalType> = Arc::new(NativeType::Struct(
+            LogicalFields::from(&Fields::from(vec![
+                Field::new("x", DataType::Float64, false),
+                Field::new("y", DataType::Float64, true),
+            ])),
+        ));
+        assert_snapshot!(struct_type, @r#"Struct("x": non-null Float64, "y": Float64)"#);
+    }
+
+    #[test]
+    fn test_logical_type_display_fixed_size_list() {
+        let fsl_type: Arc<dyn LogicalType> = Arc::new(NativeType::FixedSizeList(
+            Arc::new(LogicalField::from(&Field::new(
+                "item",
+                DataType::Float32,
+                false,
+            ))),
+            3,
+        ));
+        assert_snapshot!(fsl_type, @"FixedSizeList(3 x non-null Float32)");
+    }
+
+    #[test]
+    fn test_logical_type_display_map() {
+        let map_type: Arc<dyn LogicalType> = Arc::new(NativeType::Map(Arc::new(
+            LogicalField::from(&Field::new("entries", DataType::Utf8, false)),
+        )));
+        assert_snapshot!(map_type, @"Map(non-null String)");
+    }
+
+    #[test]
+    fn test_logical_type_display_union() {
+        use arrow::datatypes::UnionFields;
+
+        let union_fields = UnionFields::try_new(
+            vec![0, 1],
+            vec![
+                Field::new("int_val", DataType::Int32, false),
+                Field::new("str_val", DataType::Utf8, true),
+            ],
+        )
+        .unwrap();
+        let union_type: Arc<dyn LogicalType> = Arc::new(NativeType::Union(
+            crate::types::LogicalUnionFields::from(&union_fields),
+        ));
+        assert_snapshot!(union_type, @r#"Union(0: ("int_val": non-null Int32), 1: ("str_val": String))"#);
+    }
+
+    #[test]
+    fn test_logical_type_display_nullable_vs_non_nullable() {
+        let nullable_list: Arc<dyn LogicalType> = Arc::new(NativeType::List(Arc::new(
+            LogicalField::from(&Field::new("item", DataType::Int32, true)),
+        )));
+        let non_nullable_list: Arc<dyn LogicalType> =
+            Arc::new(NativeType::List(Arc::new(LogicalField::from(&Field::new(
+                "item",
+                DataType::Int32,
+                false,
+            )))));
+
+        assert_snapshot!(nullable_list, @"List(Int32)");
+        assert_snapshot!(non_nullable_list, @"List(non-null Int32)");
+    }
+
+    #[test]
+    fn test_logical_type_display_extension() {
+        struct JsonType;
+        impl LogicalType for JsonType {
+            fn native(&self) -> &NativeType {
+                &NativeType::String
+            }
+            fn signature(&self) -> TypeSignature<'_> {
+                TypeSignature::Extension {
+                    name: "JSON",
+                    parameters: &[],
+                }
+            }
+        }
+        let json: Arc<dyn LogicalType> = Arc::new(JsonType);
+        assert_snapshot!(json, @"JSON");
+    }
+}
diff --git a/datafusion/common/src/types/mod.rs b/datafusion/common/src/types/mod.rs
index 2f9ce4ce02827..57bf921a6d564 100644
--- a/datafusion/common/src/types/mod.rs
+++ b/datafusion/common/src/types/mod.rs
@@ -16,11 +16,15 @@
// under the License. mod builtin; +mod canonical_extensions; +mod extension; mod field; mod logical; mod native; pub use builtin::*; +pub use canonical_extensions::*; +pub use extension::*; pub use field::*; pub use logical::*; pub use native::*; diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 8c41701ae5768..580d572af4c0f 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -19,11 +19,11 @@ use super::{ LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature, }; -use crate::error::{Result, _internal_err}; +use crate::error::{_internal_err, Result}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, - DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType, + Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; use std::{fmt::Display, sync::Arc}; @@ -184,9 +184,82 @@ pub enum NativeType { Map(LogicalFieldRef), } +/// Format a [`LogicalField`] for display, matching [`arrow::datatypes::DataType`]'s +/// Display convention of showing a `"non-null "` prefix for non-nullable fields. +fn format_logical_field( + f: &mut std::fmt::Formatter<'_>, + field: &LogicalField, +) -> std::fmt::Result { + let non_null = if field.nullable { "" } else { "non-null " }; + write!(f, "{:?}: {non_null}{}", field.name, field.logical_type) +} + impl Display for NativeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") // TODO: nicer formatting + // Match the format used by arrow::datatypes::DataType's Display impl + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int8 => write!(f, "Int8"), + Self::Int16 => write!(f, "Int16"), + Self::Int32 => write!(f, "Int32"), + Self::Int64 => write!(f, "Int64"), + Self::UInt8 => write!(f, "UInt8"), + Self::UInt16 => write!(f, "UInt16"), + Self::UInt32 => write!(f, "UInt32"), + Self::UInt64 => write!(f, "UInt64"), + Self::Float16 => write!(f, "Float16"), + Self::Float32 => write!(f, "Float32"), + Self::Float64 => write!(f, "Float64"), + Self::Timestamp(unit, Some(tz)) => write!(f, "Timestamp({unit}, {tz:?})"), + Self::Timestamp(unit, None) => write!(f, "Timestamp({unit})"), + Self::Date => write!(f, "Date"), + Self::Time(unit) => write!(f, "Time({unit})"), + Self::Duration(unit) => write!(f, "Duration({unit})"), + Self::Interval(unit) => write!(f, "Interval({unit:?})"), + Self::Binary => write!(f, "Binary"), + Self::FixedSizeBinary(size) => write!(f, "FixedSizeBinary({size})"), + Self::String => write!(f, "String"), + Self::List(field) => { + let non_null = if field.nullable { "" } else { "non-null " }; + write!(f, "List({non_null}{})", field.logical_type) + } + Self::FixedSizeList(field, size) => { + let non_null = if field.nullable { "" } else { "non-null " }; + write!( + f, + "FixedSizeList({size} x {non_null}{})", + field.logical_type + ) + } + Self::Struct(fields) => { + write!(f, "Struct(")?; + for (i, field) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + format_logical_field(f, field)?; + } + write!(f, ")") + } + Self::Union(fields) => { + write!(f, "Union(")?; + for (i, (type_id, field)) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{type_id}: (")?; + format_logical_field(f, field)?; + 
+                    write!(f, ")")?;
+                }
+                write!(f, ")")
+            }
+            Self::Decimal(precision, scale) => write!(f, "Decimal({precision}, {scale})"),
+            Self::Map(field) => {
+                let non_null = if field.nullable { "" } else { "non-null " };
+                write!(f, "Map({non_null}{})", field.logical_type)
+            }
+        }
     }
 }
 
@@ -241,9 +314,7 @@ impl LogicalType for NativeType {
             (Self::Decimal(p, s), _) => Decimal256(*p, *s),
             (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
             // If given type is Date, return the same type
-            (Self::Date, origin) if matches!(origin, Date32 | Date64) => {
-                origin.to_owned()
-            }
+            (Self::Date, Date32 | Date64) => origin.to_owned(),
             (Self::Date, _) => Date32,
             (Self::Time(tu), _) => match tu {
                 TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu),
@@ -253,6 +324,8 @@ impl LogicalType for NativeType {
             (Self::Interval(iu), _) => Interval(*iu),
             (Self::Binary, LargeUtf8) => LargeBinary,
             (Self::Binary, Utf8View) => BinaryView,
+            // We don't cast to another kind of binary type if the original type is already binary
+            (Self::Binary, Binary | LargeBinary | BinaryView) => origin.to_owned(),
             (Self::Binary, data_type) if can_cast_types(data_type, &BinaryView) => {
                 BinaryView
             }
@@ -364,7 +437,7 @@ impl LogicalType for NativeType {
                     "Unavailable default cast for native type {} from physical type {}",
                     self,
                     origin
-                )
+                );
             }
         })
     }
@@ -430,22 +503,7 @@ impl From<DataType> for NativeType {
 impl NativeType {
     #[inline]
     pub fn is_numeric(&self) -> bool {
-        use NativeType::*;
-        matches!(
-            self,
-            UInt8
-                | UInt16
-                | UInt32
-                | UInt64
-                | Int8
-                | Int16
-                | Int32
-                | Int64
-                | Float16
-                | Float32
-                | Float64
-                | Decimal(_, _)
-        )
+        self.is_integer() || self.is_float() || self.is_decimal()
     }
 
     #[inline]
@@ -464,7 +522,7 @@ impl NativeType {
     #[inline]
     pub fn is_date(&self) -> bool {
-        matches!(self, NativeType::Date)
+        *self == NativeType::Date
     }
 
     #[inline]
@@ -489,6 +547,102 @@ impl NativeType {
     #[inline]
     pub fn is_null(&self) -> bool {
-        matches!(self, NativeType::Null)
+        *self == NativeType::Null
+    }
+
+    #[inline]
+    pub fn is_decimal(&self) -> bool {
+        matches!(self, Self::Decimal(_, _))
+    }
+
+    #[inline]
+    pub fn is_float(&self) -> bool {
+        matches!(self, Self::Float16 | Self::Float32 | Self::Float64)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::types::LogicalField;
+    use insta::assert_snapshot;
+
+    #[test]
+    fn test_native_type_display() {
+        assert_snapshot!(NativeType::Null, @"Null");
+        assert_snapshot!(NativeType::Boolean, @"Boolean");
+        assert_snapshot!(NativeType::Int8, @"Int8");
+        assert_snapshot!(NativeType::Int16, @"Int16");
+        assert_snapshot!(NativeType::Int32, @"Int32");
+        assert_snapshot!(NativeType::Int64, @"Int64");
+        assert_snapshot!(NativeType::UInt8, @"UInt8");
+        assert_snapshot!(NativeType::UInt16, @"UInt16");
+        assert_snapshot!(NativeType::UInt32, @"UInt32");
+        assert_snapshot!(NativeType::UInt64, @"UInt64");
+        assert_snapshot!(NativeType::Float16, @"Float16");
+        assert_snapshot!(NativeType::Float32, @"Float32");
+        assert_snapshot!(NativeType::Float64, @"Float64");
+        assert_snapshot!(NativeType::Date, @"Date");
+        assert_snapshot!(NativeType::Binary, @"Binary");
+        assert_snapshot!(NativeType::String, @"String");
+        assert_snapshot!(NativeType::FixedSizeBinary(16), @"FixedSizeBinary(16)");
+        assert_snapshot!(NativeType::Decimal(10, 2), @"Decimal(10, 2)");
+    }
+
+    #[test]
+    fn test_native_type_display_timestamp() {
+        assert_snapshot!(
+            NativeType::Timestamp(TimeUnit::Second, None),
+            @"Timestamp(s)"
+        );
+        assert_snapshot!(
+            NativeType::Timestamp(TimeUnit::Millisecond, None),
@"Timestamp(ms)" + ); + assert_snapshot!( + NativeType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))), + @r#"Timestamp(ns, "UTC")"# + ); + } + + #[test] + fn test_native_type_display_time_duration_interval() { + assert_snapshot!(NativeType::Time(TimeUnit::Microsecond), @"Time(µs)"); + assert_snapshot!(NativeType::Duration(TimeUnit::Nanosecond), @"Duration(ns)"); + assert_snapshot!(NativeType::Interval(IntervalUnit::YearMonth), @"Interval(YearMonth)"); + assert_snapshot!(NativeType::Interval(IntervalUnit::MonthDayNano), @"Interval(MonthDayNano)"); + } + + #[test] + fn test_native_type_display_nested() { + let list = NativeType::List(Arc::new(LogicalField::from(&Field::new( + "item", + DataType::Int32, + true, + )))); + assert_snapshot!(list, @"List(Int32)"); + + let fixed_list = NativeType::FixedSizeList( + Arc::new(LogicalField::from(&Field::new( + "item", + DataType::Float64, + false, + ))), + 3, + ); + assert_snapshot!(fixed_list, @"FixedSizeList(3 x non-null Float64)"); + + let struct_type = NativeType::Struct(LogicalFields::from(&Fields::from(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, true), + ]))); + assert_snapshot!(struct_type, @r#"Struct("name": non-null String, "age": Int32)"#); + + let map = NativeType::Map(Arc::new(LogicalField::from(&Field::new( + "entries", + DataType::Utf8, + false, + )))); + assert_snapshot!(map, @"Map(non-null String)"); } } diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs new file mode 100644 index 0000000000000..783ec665f3355 --- /dev/null +++ b/datafusion/common/src/utils/aggregate.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Scalar-level aggregation utilities for statistics merging. +//! +//! Provides in-place accumulation helpers that reuse the existing +//! [`ScalarValue`] accumulator when possible. + +use crate::stats::Precision; +use crate::{Result, ScalarValue}; + +/// Adds `rhs` into `lhs`, mutating the accumulator in place when +/// possible and otherwise falling back to `ScalarValue::add_checked`. +pub(crate) fn scalar_add(lhs: &mut ScalarValue, rhs: &ScalarValue) -> Result<()> { + if lhs.try_add_checked_in_place(rhs)? { + return Ok(()); + } + + *lhs = lhs.add_checked(rhs)?; + Ok(()) +} + +/// [`Precision`]-aware sum that mutates `lhs` in place when possible. +/// +/// Mirrors the semantics of `Precision::add`, including +/// checked overflow handling, but avoids allocating a fresh +/// [`ScalarValue`] for the common numeric fast path. 
+pub(crate) fn precision_add(
+    lhs: &mut Precision<ScalarValue>,
+    rhs: &Precision<ScalarValue>,
+) {
+    let (mut lhs_value, lhs_is_exact) = match std::mem::take(lhs) {
+        Precision::Exact(value) => (value, true),
+        Precision::Inexact(value) => (value, false),
+        Precision::Absent => return,
+    };
+
+    let (rhs_value, rhs_is_exact) = match rhs {
+        Precision::Exact(value) => (value, true),
+        Precision::Inexact(value) => (value, false),
+        Precision::Absent => return,
+    };
+
+    if scalar_add(&mut lhs_value, rhs_value).is_err() {
+        return;
+    }
+
+    *lhs = if lhs_is_exact && rhs_is_exact {
+        Precision::Exact(lhs_value)
+    } else {
+        Precision::Inexact(lhs_value)
+    };
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_scalar_add_null_propagates() -> Result<()> {
+        let mut lhs = ScalarValue::Int32(Some(42));
+
+        scalar_add(&mut lhs, &ScalarValue::Int32(None))?;
+
+        assert_eq!(lhs, ScalarValue::Int32(None));
+        Ok(())
+    }
+
+    #[test]
+    fn test_scalar_add_overflow_returns_error() {
+        let mut lhs = ScalarValue::Int32(Some(i32::MAX));
+
+        let err = scalar_add(&mut lhs, &ScalarValue::Int32(Some(1)))
+            .unwrap_err()
+            .strip_backtrace();
+
+        assert_eq!(
+            err,
+            "Arrow error: Arithmetic overflow: Overflow happened on: 2147483647 + 1"
+        );
+    }
+
+    #[test]
+    fn test_precision_add_null_propagates() {
+        let mut lhs = Precision::Exact(ScalarValue::Int32(Some(42)));
+
+        precision_add(&mut lhs, &Precision::Exact(ScalarValue::Int32(None)));
+
+        assert_eq!(lhs, Precision::Exact(ScalarValue::Int32(None)));
+    }
+
+    #[test]
+    fn test_precision_add_overflow_becomes_absent() {
+        let mut lhs = Precision::Exact(ScalarValue::Int32(Some(i32::MAX)));
+
+        precision_add(&mut lhs, &Precision::Exact(ScalarValue::Int32(Some(1))));
+
+        assert_eq!(lhs, Precision::Absent);
+    }
+
+    #[test]
+    fn test_precision_add_rhs_absent_absorbs() {
+        let mut lhs = Precision::Exact(ScalarValue::Int32(Some(42)));
+
+        precision_add(&mut lhs, &Precision::Absent);
+
+        assert_eq!(lhs, Precision::Absent);
+    }
+
+    #[test]
+    fn test_precision_add_mixed_exactness() {
+        let mut lhs = Precision::Exact(ScalarValue::Int32(Some(10)));
+
+        precision_add(&mut lhs, &Precision::Inexact(ScalarValue::Int32(Some(5))));
+
+        assert_eq!(lhs, Precision::Inexact(ScalarValue::Int32(Some(15))));
+    }
+}
diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs
index a56b940fab666..78ec434d2b577 100644
--- a/datafusion/common/src/utils/memory.rs
+++ b/datafusion/common/src/utils/memory.rs
@@ -18,8 +18,10 @@
 //! This module provides a function to estimate the memory size of a HashTable prior to allocation
 
 use crate::error::_exec_datafusion_err;
-use crate::Result;
-use std::mem::size_of;
+use crate::{HashSet, Result};
+use arrow::array::ArrayData;
+use arrow::record_batch::RecordBatch;
+use std::{mem::size_of, ptr::NonNull};
 
 /// Estimates the memory size required for a hash table prior to allocation.
 ///
@@ -99,6 +101,74 @@ pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize>
     })
 }
 
+/// Calculate total used memory of this batch.
+///
+/// This function is used to estimate the physical memory usage of the `RecordBatch`.
+/// It only counts the memory of large data `Buffer`s, and ignores metadata like
+/// types and pointers.
+/// The implementation will add up each unique `Buffer`'s memory
+/// size, due to:
+/// - The data pointers inside `Buffer`s are memory regions returned by the global
+///   memory allocator; those regions can't overlap.
+/// - The actual used ranges of `ArrayRef`s inside the `RecordBatch` can overlap
+///   or reuse the same `Buffer`. For example: taking a slice from an `Array`.
+///
+/// Example:
+/// For a `RecordBatch` with two columns: `col1` and `col2`, two columns are pointing
+/// to a sub-region of the same buffer.
+///
+/// {xxxxxxxxxxxxxxxxxxx} <--- buffer
+///       ^    ^  ^    ^
+///       |    |  |    |
+/// col1->{    }  |    |
+/// col2--------->{    }
+///
+/// In the above case, `get_record_batch_memory_size` will return the size of
+/// the buffer, instead of the sum of `col1` and `col2`'s actual memory size.
+///
+/// Note: the current `RecordBatch::get_array_memory_size()` will double-count the
+/// buffer memory size if multiple arrays within the batch are sharing the same
+/// `Buffer`. This method provides a temporary fix until the issue is resolved:
+///
+pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize {
+    // Store pointers to each `Buffer`'s start memory address (instead of the actual
+    // used data region's pointer represented by the current `Array`)
+    let mut counted_buffers: HashSet<NonNull<u8>> = HashSet::new();
+    let mut total_size = 0;
+
+    for array in batch.columns() {
+        let array_data = array.to_data();
+        count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size);
+    }
+
+    total_size
+}
+
+/// Count the memory usage of `array_data` and its children recursively.
+fn count_array_data_memory_size(
+    array_data: &ArrayData,
+    counted_buffers: &mut HashSet<NonNull<u8>>,
+    total_size: &mut usize,
+) {
+    // Count memory usage for `array_data`
+    for buffer in array_data.buffers() {
+        if counted_buffers.insert(buffer.data_ptr()) {
+            *total_size += buffer.capacity();
+        } // Otherwise the buffer's memory is already counted
+    }
+
+    if let Some(null_buffer) = array_data.nulls()
+        && counted_buffers.insert(null_buffer.inner().inner().data_ptr())
+    {
+        *total_size += null_buffer.inner().inner().capacity();
+    }
+
+    // Count all children `ArrayData` recursively
+    for child in array_data.child_data() {
+        count_array_data_memory_size(child, counted_buffers, total_size);
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{collections::HashSet, mem::size_of};
@@ -132,3 +202,129 @@ mod tests {
         assert!(estimated.is_err());
     }
 }
+
+#[cfg(test)]
+mod record_batch_tests {
+    use super::*;
+    use arrow::array::{Float64Array, Int32Array, ListArray};
+    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
+    use std::sync::Arc;
+
+    #[test]
+    fn test_get_record_batch_memory_size() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("ints", DataType::Int32, true),
+            Field::new("float64", DataType::Float64, false),
+        ]));
+
+        let int_array =
+            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
+        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
+
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(int_array), Arc::new(float64_array)],
+        )
+        .unwrap();
+
+        let size = get_record_batch_memory_size(&batch);
+        assert_eq!(size, 60);
+    }
+
+    #[test]
+    fn test_get_record_batch_memory_size_with_null() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("ints", DataType::Int32, true),
+            Field::new("float64", DataType::Float64, false),
+        ]));
+
+        let int_array = Int32Array::from(vec![None, Some(2), Some(3)]);
+        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0]);
+
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(int_array), Arc::new(float64_array)],
+        )
+        .unwrap();
+
+        let size = get_record_batch_memory_size(&batch);
+        assert_eq!(size, 100);
+    }
+
+    #[test]
+    fn test_get_record_batch_memory_size_empty() {
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "ints",
+            DataType::Int32,
+            false,
+        )]));
+
+        let int_array: Int32Array = Int32Array::from(vec![] as Vec<i32>);
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap();
+
+        let size = get_record_batch_memory_size(&batch);
+        assert_eq!(size, 0, "Empty batch should have 0 memory size");
+    }
+
+    #[test]
+    fn test_get_record_batch_memory_size_shared_buffer() {
+        let original = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let slice1 = original.slice(0, 3);
+        let slice2 = original.slice(2, 3);
+
+        let schema_origin = Arc::new(Schema::new(vec![Field::new(
+            "origin_col",
+            DataType::Int32,
+            false,
+        )]));
+        let batch_origin =
+            RecordBatch::try_new(schema_origin, vec![Arc::new(original)]).unwrap();
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("slice1", DataType::Int32, false),
+            Field::new("slice2", DataType::Int32, false),
+        ]));
+
+        let batch_sliced =
+            RecordBatch::try_new(schema, vec![Arc::new(slice1), Arc::new(slice2)])
+                .unwrap();
+
+        let size_origin = get_record_batch_memory_size(&batch_origin);
+        let size_sliced = get_record_batch_memory_size(&batch_sliced);
+
+        assert_eq!(size_origin, size_sliced);
+    }
+
+    #[test]
+    fn test_get_record_batch_memory_size_nested_array() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new(
+                "nested_int",
+                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
+                false,
+            ),
+            Field::new(
+                "nested_int2",
+                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
+                false,
+            ),
+        ]));
+
+        let int_list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+        ]);
+
+        let int_list_array2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(4), Some(5), Some(6)]),
+        ]);
+
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(int_list_array), Arc::new(int_list_array2)],
+        )
+        .unwrap();
+
+        let size = get_record_batch_memory_size(&batch);
+        assert_eq!(size, 8208);
+    }
+}
diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs
index 7b145ac3ae21d..8c88be03fd5c8 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -17,28 +17,36 @@
 
 //! This module provides the bisect function, which implements binary search.
+pub(crate) mod aggregate;
 pub mod expr;
 pub mod memory;
 pub mod proxy;
 pub mod string_utils;
 
-use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
+use crate::assert_or_internal_err;
+use crate::error::{_exec_datafusion_err, _exec_err, _internal_datafusion_err};
 use crate::{Result, ScalarValue};
 use arrow::array::{
-    cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
-    OffsetSizeTrait,
+    Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait,
+    cast::AsArray,
 };
-use arrow::buffer::OffsetBuffer;
-use arrow::compute::{partition, SortColumn, SortOptions};
+use arrow::array::{
+    Datum, GenericListArray, Int32Array, Int64Array, MutableArrayData, make_array,
+};
+use arrow::array::{LargeListViewArray, ListViewArray};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+use arrow::compute::kernels::cmp::neq;
+use arrow::compute::kernels::length::length;
+use arrow::compute::{SortColumn, SortOptions, partition};
 use arrow::datatypes::{DataType, Field, SchemaRef};
 #[cfg(feature = "sql")]
 use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser};
 use std::borrow::{Borrow, Cow};
-use std::cmp::{min, Ordering};
+use std::cmp::{Ordering, min};
 use std::collections::HashSet;
 use std::num::NonZero;
 use std::ops::Range;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
 use std::thread::available_parallelism;
 
 /// Applies an optional projection to a [`SchemaRef`], returning the
@@ -69,10 +77,10 @@ use std::thread::available_parallelism;
 /// ```
 pub fn project_schema(
     schema: &SchemaRef,
-    projection: Option<&Vec<usize>>,
+    projection: Option<&impl AsRef<[usize]>>,
 ) -> Result<SchemaRef> {
     let schema = match projection {
-        Some(columns) => Arc::new(schema.project(columns)?),
+        Some(columns) => Arc::new(schema.project(columns.as_ref())?),
         None => Arc::clone(schema),
     };
     Ok(schema)
@@ -265,10 +273,10 @@ fn needs_quotes(s: &str) -> bool {
     let mut chars = s.chars();
 
     // first char can not be a number unless escaped
-    if let Some(first_char) = chars.next() {
-        if !(first_char.is_ascii_lowercase() || first_char == '_') {
-            return true;
-        }
+    if let Some(first_char) = chars.next()
+        && !(first_char.is_ascii_lowercase() || first_char == '_')
+    {
+        return true;
     }
 
     !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
@@ -478,6 +486,34 @@ impl SingleRowListArrayBuilder {
         ScalarValue::FixedSizeList(Arc::new(self.build_fixed_size_list_array(list_size)))
     }
 
+    /// Build a single element [`ListViewArray`]
+    pub fn build_list_view_array(self) -> ListViewArray {
+        let (field, arr) = self.into_field_and_arr();
+        let offsets = ScalarBuffer::from(vec![0]);
+        let sizes = ScalarBuffer::from(vec![i32::try_from(arr.len()).expect(
+            "Trying to construct a ListView where element length exceeds i32::MAX",
+        )]);
+        ListViewArray::new(field, offsets, sizes, arr, None)
+    }
+
+    /// Build a single element [`ListViewArray`] and wrap as [`ScalarValue::ListView`]
+    pub fn build_list_view_scalar(self) -> ScalarValue {
+        ScalarValue::ListView(Arc::new(self.build_list_view_array()))
+    }
+
+    /// Build a single element [`LargeListViewArray`]
+    pub fn build_large_list_view_array(self) -> LargeListViewArray {
+        let (field, arr) = self.into_field_and_arr();
+        let offsets = ScalarBuffer::from(vec![0]);
+        let sizes = ScalarBuffer::from(vec![arr.len() as i64]);
+        LargeListViewArray::new(field, offsets, sizes, arr, None)
+    }
+
+    /// Build a single element [`LargeListViewArray`] and wrap as [`ScalarValue::LargeListView`]
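+    ///
+    /// A minimal usage sketch (it assumes `values` is an `ArrayRef` already in scope):
+    ///
+    /// ```ignore
+    /// let scalar = SingleRowListArrayBuilder::new(values).build_large_list_view_scalar();
+    /// assert!(matches!(scalar, ScalarValue::LargeListView(_)));
+    /// ```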
+    pub fn build_large_list_view_scalar(self) -> ScalarValue {
+        ScalarValue::LargeListView(Arc::new(self.build_large_list_view_array()))
+    }
+
     /// Helper function: convert this builder into a tuple of field and array
     fn into_field_and_arr(self) -> (Arc<Field>, ArrayRef) {
         let Self {
@@ -515,13 +551,12 @@ impl SingleRowListArrayBuilder {
 /// );
 ///
 /// assert_eq!(list_arr, expected);
+/// ```
 pub fn arrays_into_list_array(
     arr: impl IntoIterator<Item = ArrayRef>,
 ) -> Result<ListArray> {
     let arr = arr.into_iter().collect::<Vec<_>>();
-    if arr.is_empty() {
-        return _internal_err!("Cannot wrap empty array into list array");
-    }
+    assert_or_internal_err!(!arr.is_empty(), "Cannot wrap empty array into list array");
 
     let lens = arr.iter().map(|x| x.len()).collect::<Vec<_>>();
     // Assume data type is consistent
@@ -564,11 +599,17 @@ pub fn base_type(data_type: &DataType) -> DataType {
     match data_type {
         DataType::List(field)
         | DataType::LargeList(field)
+        | DataType::ListView(field)
+        | DataType::LargeListView(field)
        | DataType::FixedSizeList(field, _) => base_type(field.data_type()),
        _ => data_type.to_owned(),
    }
 }
 
+// TODO: Modify this to also allow specifying how listviews should be treated.
+// For example if cast to List (default) or maintain as ListView (requires
+// function to implement support for ListViews)
+// https://github.com/apache/datafusion/issues/21777
 /// Information about how to coerce lists.
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
 pub enum ListCoercion {
@@ -588,6 +629,7 @@ pub enum ListCoercion {
 /// let base_type = DataType::Float64;
 /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None);
 /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))));
+/// ```
 pub fn coerced_type_with_base_type_only(
     data_type: &DataType,
     base_type: &DataType,
@@ -621,6 +663,19 @@ pub fn coerced_type_with_base_type_only(
                 *len,
             )
         }
+        (DataType::ListView(field), _) => {
+            let field_type = coerced_type_with_base_type_only(
+                field.data_type(),
+                base_type,
+                array_coercion,
+            );
+
+            DataType::ListView(Arc::new(Field::new(
+                field.name(),
+                field_type,
+                field.is_nullable(),
+            )))
+        }
        (DataType::LargeList(field), _) => {
            let field_type = coerced_type_with_base_type_only(
                field.data_type(),
@@ -634,6 +689,19 @@ pub fn coerced_type_with_base_type_only(
                field.is_nullable(),
            )))
        }
+        (DataType::LargeListView(field), _) => {
+            let field_type = coerced_type_with_base_type_only(
+                field.data_type(),
+                base_type,
+                array_coercion,
+            );
+
+            DataType::LargeListView(Arc::new(Field::new(
+                field.name(),
+                field_type,
+                field.is_nullable(),
+            )))
+        }
 
         _ => base_type.clone(),
     }
@@ -651,6 +719,15 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
                field.is_nullable(),
            )))
        }
+        DataType::ListView(field) => {
+            let field_type = coerced_fixed_size_list_to_list(field.data_type());
+
+            DataType::ListView(Arc::new(Field::new(
+                field.name(),
+                field_type,
+                field.is_nullable(),
+            )))
+        }
        DataType::LargeList(field) => {
            let field_type = coerced_fixed_size_list_to_list(field.data_type());
 
@@ -660,6 +737,15 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
                field.is_nullable(),
            )))
        }
+        DataType::LargeListView(field) => {
+            let field_type = coerced_fixed_size_list_to_list(field.data_type());
+
+            DataType::LargeListView(Arc::new(Field::new(
+                field.name(),
+                field_type,
+                field.is_nullable(),
+            )))
+        }
 
        _ => data_type.clone(),
    }
@@ -670,6 +756,8 @@ pub fn list_ndims(data_type: &DataType) -> u64 {
     match data_type {
         DataType::List(field)
         | DataType::LargeList(field)
+        | DataType::ListView(field)
+        | DataType::LargeListView(field)
         | DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()),
         _ => 0,
     }
@@ -694,10 +782,14 @@ pub mod datafusion_strsim {
     }
 
     /// Calculates the minimum number of insertions, deletions, and substitutions
-    /// required to change one sequence into the other.
-    fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>(
+    /// required to change one sequence into the other, using a reusable cache buffer.
+    ///
+    /// This is the generic implementation that works with any iterator types.
+    /// The `cache` buffer will be resized as needed and reused across calls.
+    fn generic_levenshtein_with_buffer<'a, 'b, Iter1, Iter2, Elem1, Elem2>(
         a: &'a Iter1,
         b: &'b Iter2,
+        cache: &mut Vec<usize>,
     ) -> usize
     where
         &'a Iter1: IntoIterator<Item = Elem1>,
         &'b Iter2: IntoIterator<Item = Elem2>,
         Elem1: PartialEq<Elem2>,
     {
@@ -710,7 +802,9 @@
             return b_len;
         }
 
-        let mut cache: Vec<usize> = (1..b_len + 1).collect();
+        // Resize cache to fit b_len elements
+        cache.clear();
+        cache.extend(1..=b_len);
 
         let mut result = 0;
 
@@ -730,6 +824,21 @@
         result
     }
 
+    /// Calculates the minimum number of insertions, deletions, and substitutions
+    /// required to change one sequence into the other.
+    fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>(
+        a: &'a Iter1,
+        b: &'b Iter2,
+    ) -> usize
+    where
+        &'a Iter1: IntoIterator<Item = Elem1>,
+        &'b Iter2: IntoIterator<Item = Elem2>,
+        Elem1: PartialEq<Elem2>,
+    {
+        let mut cache = Vec::new();
+        generic_levenshtein_with_buffer(a, b, &mut cache)
+    }
+
     /// Calculates the minimum number of insertions, deletions, and substitutions
     /// required to change one string into the other.
     ///
@@ -742,6 +851,15 @@
         generic_levenshtein(&StringWrapper(a), &StringWrapper(b))
     }
 
+    /// Calculates the Levenshtein distance using a reusable cache buffer.
+    /// This avoids allocating a new Vec for each call, improving performance
+    /// when computing many distances.
+    ///
+    /// The `cache` buffer will be resized as needed and reused across calls.
+    pub fn levenshtein_with_buffer(a: &str, b: &str, cache: &mut Vec<usize>) -> usize {
+        generic_levenshtein_with_buffer(&StringWrapper(a), &StringWrapper(b), cache)
+    }
+
     /// Calculates the normalized Levenshtein distance between two strings.
     /// The normalized distance is a value between 0.0 and 1.0, where 1.0 indicates
     /// that the strings are identical and 0.0 indicates no similarity.
@@ -891,10 +1009,15 @@ pub fn combine_limit(
 ///
 /// This is a wrapper around `std::thread::available_parallelism`, providing a default value
 /// of `1` if the system's parallelism cannot be determined.
+///
+/// The result is cached after the first call.
 pub fn get_available_parallelism() -> usize {
-    available_parallelism()
-        .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero"))
-        .get()
+    static PARALLELISM: LazyLock<usize> = LazyLock::new(|| {
+        available_parallelism()
+            .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero"))
+            .get()
+    });
+    *PARALLELISM
 }
 
 /// Converts a collection of function arguments into a fixed-size array of length N
@@ -939,13 +1062,137 @@ pub fn take_function_args(
     })
 }
 
+/// Returns the inner values of a list, or an error otherwise.
+/// For [`ListArray`] and [`LargeListArray`], if it's sliced, it returns a
+/// sliced array too.
+/// Therefore, to reconstruct a list using it,
+/// you must adjust the offsets using [`adjust_offsets_for_slice`].
+pub fn list_values(array: &dyn Array) -> Result<ArrayRef> {
+    match array.data_type() {
+        DataType::List(_) => Ok(sliced_list_values(array.as_list::<i32>())),
+        DataType::LargeList(_) => Ok(sliced_list_values(array.as_list::<i64>())),
+        DataType::FixedSizeList(_, _) => {
+            Ok(Arc::clone(array.as_fixed_size_list().values()))
+        }
+        other => _exec_err!("expected list, got {other}"),
+    }
+}
+
+fn sliced_list_values<O: OffsetSizeTrait>(list: &GenericListArray<O>) -> ArrayRef {
+    let values = list.values();
+    let offsets = list.offsets();
+
+    if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) {
+        let first = first.as_usize();
+        let last = last.as_usize();
+
+        if first != 0 || last != values.len() {
+            return values.slice(first, last - first);
+        }
+    }
+
+    Arc::clone(values)
+}
+
+/// If `list` is sliced, returns an adjusted offset buffer so that
+/// it points to the sliced portion of the list values, and not the whole list values.
+pub fn adjust_offsets_for_slice<O: OffsetSizeTrait>(
+    list: &GenericListArray<O>,
+) -> OffsetBuffer<O> {
+    let offsets = list.offsets();
+
+    if let (Some(first), Some(last)) = (offsets.first(), offsets.last())
+        && (!first.is_zero() || last.as_usize() != list.values().len())
+    {
+        let offsets = offsets.iter().map(|offset| *offset - *first).collect();
+
+        // TODO: use unsafe OffsetBuffer::new_unchecked?
+        return OffsetBuffer::new(offsets);
+    }
+
+    offsets.clone()
+}
+
+/// For lists and large lists, truncates the sublists of null entries.
+/// Otherwise returns an error.
+pub fn remove_list_null_values(array: &ArrayRef) -> Result<ArrayRef> {
+    // todo: handle list view and map
+    match array.data_type() {
+        DataType::List(_) => Ok(Arc::new(truncate_list_nulls(array.as_list::<i32>())?)),
+        DataType::LargeList(_) => {
+            Ok(Arc::new(truncate_list_nulls(array.as_list::<i64>())?))
+        }
+        dt => _exec_err!("expected List or LargeList, got {dt}"),
+    }
+}
+
+fn truncate_list_nulls<O: OffsetSizeTrait>(
+    list: &GenericListArray<O>,
+) -> Result<GenericListArray<O>> {
+    if let Some(nulls) = list.nulls()
+        && nulls.null_count() > 0
+    {
+        let lengths = length(list)?;
+        let zero: &dyn Datum = if lengths.data_type() == &DataType::Int32 {
+            &Int32Array::new_scalar(0)
+        } else {
+            &Int64Array::new_scalar(0)
+        };
+
+        let not_empty = neq(&lengths, zero)?;
+        let null_and_non_empty = &!nulls.inner() & not_empty.values();
+
+        if null_and_non_empty.count_set_bits() > 0 {
+            let array_data = list.values().to_data();
+            let offsets = list.offsets();
+            let capacity = offsets[offsets.len() - 1] - offsets[0];
+            let mut mutable_array_data =
+                MutableArrayData::new(vec![&array_data], false, capacity.as_usize());
+
+            let valid_or_empty = nulls.inner() | &!not_empty.values();
+
+            for (start, end) in valid_or_empty.set_slices() {
+                mutable_array_data.extend(
+                    0,
+                    offsets[start].as_usize(),
+                    offsets[end].as_usize(),
+                );
+            }
+
+            let lengths = std::iter::zip(offsets.lengths(), nulls)
+                .map(|(length, is_valid)| if is_valid { length } else { 0 });
+
+            let offsets = OffsetBuffer::from_lengths(lengths);
+            let values = make_array(mutable_array_data.freeze());
+
+            let field = match list.data_type() {
+                DataType::List(field) => field,
+                DataType::LargeList(field) => field,
+                _ => unreachable!(),
+            };
+
+            return Ok(GenericListArray::try_new(
+                Arc::clone(field),
+                offsets,
+                values,
+                list.nulls().cloned(),
+            )?);
+        }
+    }
+    Ok(list.clone())
+}
+
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
     use super::*;
     use crate::ScalarValue::Null;
-    use arrow::array::Float64Array;
+    use arrow::{
+        array::{Float64Array, Int32Array},
+        buffer::NullBuffer,
+        datatypes::Int32Type,
+    };
     use sqlparser::ast::Ident;
-    use sqlparser::tokenizer::Span;
 
     #[test]
     fn test_bisect_linear_left_and_right() -> Result<()> {
@@ -1174,7 +1421,7 @@ mod tests {
         let expected_parsed = vec![Ident {
             value: identifier.to_string(),
             quote_style,
-            span: Span::empty(),
+            span: sqlparser::tokenizer::Span::empty(),
         }];
 
         assert_eq!(
@@ -1245,4 +1492,129 @@ mod tests {
         assert_eq!(expected, transposed);
         Ok(())
     }
+
+    #[test]
+    fn test_sliced_list_values() {
+        let data = vec![
+            Some(vec![Some(0), Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), None, Some(5)]),
+            Some(vec![Some(6), Some(7)]),
+        ];
+
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
+
+        assert_eq!(
+            sliced_list_values(&list).as_primitive(),
+            &Int32Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(2),
+                Some(3),
+                None,
+                Some(5),
+                Some(6),
+                Some(7)
+            ])
+        );
+
+        assert_eq!(
+            sliced_list_values(&list.slice(0, 1)).as_primitive(),
+            &Int32Array::from(vec![Some(0), Some(1), Some(2)])
+        );
+
+        assert_eq!(
+            sliced_list_values(&list.slice(2, 1)).as_primitive(),
+            &Int32Array::from(vec![Some(3), None, Some(5)])
+        );
+
+        assert_eq!(
+            sliced_list_values(&list.slice(3, 1)).as_primitive(),
+            &Int32Array::from(vec![Some(6), Some(7)])
+        );
+
+        assert!(sliced_list_values(&list.slice(0, 0)).is_empty());
+        assert!(sliced_list_values(&list.slice(1, 0)).is_empty());
+        assert!(sliced_list_values(&list.slice(3, 0)).is_empty());
+    }
+
+    #[test]
+    fn test_adjust_offsets() {
+        let data = vec![
+            Some(vec![Some(0), Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), None, Some(5)]),
+            Some(vec![Some(6), Some(7)]),
+        ];
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list),
+            OffsetBuffer::from_lengths([3, 0, 3, 2])
+        );
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list.slice(0, 1)),
+            OffsetBuffer::from_lengths([3])
+        );
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list.slice(1, 2)),
+            OffsetBuffer::from_lengths([0, 3])
+        );
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list.slice(1, 3)),
+            OffsetBuffer::from_lengths([0, 3, 2])
+        );
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list.slice(0, 0)),
+            OffsetBuffer::from_lengths([])
+        );
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list.slice(1, 0)),
+            OffsetBuffer::from_lengths([])
+        );
+
+        assert_eq!(
+            adjust_offsets_for_slice(&list.slice(3, 0)),
+            OffsetBuffer::from_lengths([])
+        );
+    }
+
+    fn create_i32_list(
+        values: impl Into<Int32Array>,
+        offsets: OffsetBuffer<i32>,
+        nulls: Option<NullBuffer>,
+    ) -> ListArray {
+        let list_field = Arc::new(Field::new_list_field(DataType::Int32, true));
+
+        ListArray::new(list_field, offsets, Arc::new(values.into()), nulls)
+    }
+
+    #[test]
+    fn test_remove_list_null_values_list() {
+        let list = Arc::new(create_i32_list(
+            vec![100, 20, 10, 0, 0, 0, 0, 1, 50],
+            OffsetBuffer::<i32>::from_lengths(vec![3, 4, 0, 2, 0]),
+            Some(NullBuffer::from(vec![true, false, false, true, false])),
+        )) as ArrayRef;
+
+        let res = remove_list_null_values(&list).unwrap();
+        let res = res.as_list::<i32>();
+
+        let expected = Arc::new(create_i32_list(
+            vec![100, 20, 10, 1, 50],
+            OffsetBuffer::<i32>::from_lengths(vec![3, 0, 0, 2, 0]),
+            Some(NullBuffer::from(vec![true, false, false, true, false])),
+        )) as ArrayRef;
+        let expected = expected.as_list::<i32>();
+
+        assert_eq!(res, expected);
+        // check above skips inner value of nulls
+        assert_eq!(res.values(), expected.values());
+        assert_eq!(res.offsets(), expected.offsets());
+    }
+}
diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs
index fb951aa3b0289..846c928515d60 100644
--- a/datafusion/common/src/utils/proxy.rs
+++ b/datafusion/common/src/utils/proxy.rs
@@ -15,12 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations
+//! [`VecAllocExt`] to help tracking of memory allocations
 
-use hashbrown::{
-    hash_table::HashTable,
-    raw::{Bucket, RawTable},
-};
+use hashbrown::hash_table::HashTable;
 use std::mem::size_of;
 
 /// Extension trait for [`Vec`] to account for allocations.
@@ -114,75 +111,6 @@ impl VecAllocExt for Vec {
     }
 }
 
-/// Extension trait for hash browns [`RawTable`] to account for allocations.
-pub trait RawTableAllocExt {
-    /// Item type.
-    type T;
-
-    /// [Insert](RawTable::insert) new element into table and increase
-    /// `accounting` by any newly allocated bytes.
-    ///
-    /// Returns the bucket where the element was inserted.
-    /// Note that allocation counts capacity, not size.
-    ///
-    /// # Example:
-    /// ```
-    /// # use datafusion_common::utils::proxy::RawTableAllocExt;
-    /// # use hashbrown::raw::RawTable;
-    /// let mut table = RawTable::new();
-    /// let mut allocated = 0;
-    /// let hash_fn = |x: &u32| (*x as u64) % 1000;
-    /// // pretend 0x3117 is the hash value for 1
-    /// table.insert_accounted(1, hash_fn, &mut allocated);
-    /// assert_eq!(allocated, 64);
-    ///
-    /// // insert more values
-    /// for i in 0..100 {
-    ///     table.insert_accounted(i, hash_fn, &mut allocated);
-    /// }
-    /// assert_eq!(allocated, 400);
-    /// ```
-    fn insert_accounted(
-        &mut self,
-        x: Self::T,
-        hasher: impl Fn(&Self::T) -> u64,
-        accounting: &mut usize,
-    ) -> Bucket<Self::T>;
-}
-
-impl<T> RawTableAllocExt for RawTable<T> {
-    type T = T;
-
-    fn insert_accounted(
-        &mut self,
-        x: Self::T,
-        hasher: impl Fn(&Self::T) -> u64,
-        accounting: &mut usize,
-    ) -> Bucket<T> {
-        let hash = hasher(&x);
-
-        match self.try_insert_no_grow(hash, x) {
-            Ok(bucket) => bucket,
-            Err(x) => {
-                // need to request more memory
-
-                let bump_elements = self.capacity().max(16);
-                let bump_size = bump_elements * size_of::<T>();
-                *accounting = (*accounting).checked_add(bump_size).expect("overflow");
-
-                self.reserve(bump_elements, hasher);
-
-                // still need to insert the element since first try failed
-                // Note: cannot use `.expect` here because `T` may not implement `Debug`
-                match self.try_insert_no_grow(hash, x) {
-                    Ok(bucket) => bucket,
-                    Err(_) => panic!("just grew the container"),
-                }
-            }
-        }
-    }
-}
-
 /// Extension trait for hash browns [`HashTable`] to account for allocations.
 pub trait HashTableAllocExt {
     /// Item type.
@@ -193,6 +121,8 @@ pub trait HashTableAllocExt {
     ///
     /// Returns the bucket where the element was inserted.
     /// Note that allocation counts capacity, not size.
+    /// # Panics:
+    /// Assumes the element is not already present, and may panic if it is.
     ///
     /// # Example:
     /// ```
@@ -206,7 +136,7 @@ pub trait HashTableAllocExt {
     /// assert_eq!(allocated, 64);
     ///
     /// // insert more values
-    /// for i in 0..100 {
+    /// for i in 2..100 {
    ///     table.insert_accounted(i, hash_fn, &mut allocated);
    /// }
    /// assert_eq!(allocated, 400);
@@ -233,22 +163,24 @@ impl<T> HashTableAllocExt for HashTable<T>
 where
     T: Eq,
 {
     fn insert_accounted(
         &mut self,
         x: Self::T,
         hasher: impl Fn(&Self::T) -> u64,
         accounting: &mut usize,
     ) {
         let hash = hasher(&x);
 
-        // NOTE: `find_entry` does NOT grow!
-        match self.find_entry(hash, |y| y == &x) {
-            Ok(_occupied) => {}
-            Err(_absent) => {
-                if self.len() == self.capacity() {
-                    // need to request more memory
-                    let bump_elements = self.capacity().max(16);
-                    let bump_size = bump_elements * size_of::<T>();
-                    *accounting = (*accounting).checked_add(bump_size).expect("overflow");
+        if cfg!(debug_assertions) {
+            // In debug mode, check that the element is not already present
+            debug_assert!(
+                self.find_entry(hash, |y| y == &x).is_err(),
+                "attempted to insert duplicate element into HashTableAllocExt::insert_accounted"
+            );
+        }
 
-                    self.reserve(bump_elements, &hasher);
-                }
+        if self.len() == self.capacity() {
+            // need to request more memory
+            let bump_elements = self.capacity().max(16);
+            let bump_size = bump_elements * size_of::<T>();
+            *accounting = (*accounting).checked_add(bump_size).expect("overflow");
 
-                // still need to insert the element since first try failed
-                self.entry(hash, |y| y == &x, hasher).insert(x);
-            }
+            self.reserve(bump_elements, &hasher);
         }
+
+        // We assume the element is not already present
+        self.insert_unique(hash, x, hasher);
     }
 }
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index f672e3a946816..ebd05392c926d 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -19,7 +19,7 @@
 name = "datafusion"
 description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model"
 keywords = ["arrow", "query", "sql"]
-include = ["benches/*.rs", "src/**/*.rs", "Cargo.toml", "LICENSE.txt", "NOTICE.txt"]
+include = ["benches/*.rs", "src/**/*.md", "src/**/*.rs", "Cargo.toml", "LICENSE.txt", "NOTICE.txt"]
 readme = "../../README.md"
 version = { workspace = true }
 edition = { workspace = true }
@@ -32,6 +32,9 @@ rust-version = { workspace = true }
 [package.metadata.docs.rs]
 all-features = true
 
+# Note: add additional linter rules in lib.rs.
+# Rust does not support workspace + new linter rules in subcrates yet
+# https://github.com/rust-lang/cargo/issues/13157
 [lints]
 workspace = true
 
@@ -40,10 +43,10 @@ nested_expressions = ["datafusion-functions-nested"]
 # This feature is deprecated. Use the `nested_expressions` feature instead.
array_expressions = ["nested_expressions"] # Used to enable the avro format -avro = ["datafusion-common/avro", "datafusion-datasource-avro"] +avro = ["datafusion-datasource-avro"] backtrace = ["datafusion-common/backtrace"] compression = [ - "xz2", + "liblzma", "bzip2", "flate2", "zstd", @@ -76,7 +79,6 @@ parquet_encryption = [ "datafusion-common/parquet_encryption", "datafusion-datasource-parquet/parquet_encryption", ] -pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = [ "datafusion-functions/regex_expressions", ] @@ -85,8 +87,9 @@ recursive_protection = [ "datafusion-expr/recursive_protection", "datafusion-optimizer/recursive_protection", "datafusion-physical-optimizer/recursive_protection", - "datafusion-sql/recursive_protection", - "sqlparser/recursive-protection", + "datafusion-physical-expr/recursive_protection", + "datafusion-sql?/recursive_protection", + "sqlparser?/recursive-protection", ] serde = [ "dep:serde", @@ -111,8 +114,7 @@ extended_tests = [] arrow = { workspace = true } arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } -bytes = { workspace = true } -bzip2 = { version = "0.6.1", optional = true } +bzip2 = { workspace = true, optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } @@ -140,24 +142,22 @@ datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } datafusion-sql = { workspace = true, optional = true } -flate2 = { version = "1.1.4", optional = true } +flate2 = { workspace = true, optional = true } futures = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } +liblzma = { workspace = true, optional = true } log = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } -rand = { workspace = true } -regex = { workspace = true } -rstest = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } sqlparser = { workspace = true, optional = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.18", features = ["v4", "js"] } -xz2 = { version = "0.1", optional = true, features = ["static"] } -zstd = { version = "0.13", optional = true, default-features = false } +uuid = { workspace = true, features = ["v4", "js"] } +zstd = { workspace = true, optional = true } [dev-dependencies] async-trait = { workspace = true } @@ -169,16 +169,17 @@ datafusion-functions-window-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } +bytes = { workspace = true } env_logger = { workspace = true } -glob = { version = "0.3.0" } +glob = { workspace = true } insta = { workspace = true } -paste = "^1.0" rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.5" +recursive = { workspace = true } regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = "0.37.2" +sysinfo = "0.38.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } @@ -186,7 +187,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] ignored = 
["datafusion-doc", "datafusion-macros", "dashmap"] [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } [[bench]] harness = false @@ -224,6 +225,10 @@ name = "struct_query_sql" harness = false name = "window_query_sql" +[[bench]] +harness = false +name = "topk_repartition" + [[bench]] harness = false name = "scalar" @@ -237,6 +242,20 @@ harness = false name = "parquet_query_sql" required-features = ["parquet"] +[[bench]] +harness = false +name = "parquet_struct_query" +required-features = ["parquet"] + +[[bench]] +harness = false +name = "parquet_struct_projection" +required-features = ["parquet"] + +[[bench]] +harness = false +name = "range_and_generate_series" + [[bench]] harness = false name = "sql_planner" @@ -269,3 +288,12 @@ name = "dataframe" [[bench]] harness = false name = "spm" + +[[bench]] +harness = false +name = "preserve_file_partitioning" +required-features = ["parquet"] + +[[bench]] +harness = false +name = "reset_plan_states" diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 87aeed49337eb..d7e24aceba170 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -15,14 +15,9 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; @@ -31,6 +26,7 @@ use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); black_box(rt.block_on(df.collect()).unwrap()); @@ -255,6 +251,83 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("array_agg_query_group_by_few_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_narrow, array_agg(f64) \ + FROM t GROUP BY u64_narrow", + ) + }) + }); + + c.bench_function("array_agg_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, array_agg(f64) \ + FROM t GROUP BY u64_mid", + ) + }) + }); + + c.bench_function("array_agg_query_group_by_many_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_wide, array_agg(f64) \ + FROM t GROUP BY u64_wide", + ) + }) + }); + + c.bench_function("array_agg_struct_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, array_agg(named_struct('market', dict10, 'price', f64)) \ + FROM t GROUP BY u64_mid", + ) + }) + }); + + c.bench_function("string_agg_query_group_by_few_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_narrow, string_agg(utf8, ',') \ + FROM t GROUP BY u64_narrow", + ) + }) + }); + + c.bench_function("string_agg_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, string_agg(utf8, ',') \ + FROM t GROUP BY u64_mid", + ) + }) + }); + + c.bench_function("string_agg_query_group_by_many_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_wide, string_agg(utf8, ',') \ + FROM t GROUP BY u64_wide", + ) + }) + }); } criterion_group!(benches, 
criterion_benchmark);
diff --git a/datafusion/core/benches/csv_load.rs b/datafusion/core/benches/csv_load.rs
index de0f0d8250572..13843dadddd0c 100644
--- a/datafusion/core/benches/csv_load.rs
+++ b/datafusion/core/benches/csv_load.rs
@@ -15,14 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#[macro_use]
-extern crate criterion;
-extern crate arrow;
-extern crate datafusion;
-
 mod data_utils;
 
-use crate::criterion::Criterion;
+use criterion::{Criterion, criterion_group, criterion_main};
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::CsvReadOptions;
@@ -34,6 +29,7 @@ use std::time::Duration;
 use test_utils::AccessLogGenerator;
 use tokio::runtime::Runtime;
 
+#[expect(clippy::needless_pass_by_value)]
 fn load_csv(
     ctx: Arc<Mutex<SessionContext>>,
     rt: &Runtime,
diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs
index fffe2e2d17522..728c6490c72bd 100644
--- a/datafusion/core/benches/data_utils/mod.rs
+++ b/datafusion/core/benches/data_utils/mod.rs
@@ -18,10 +18,11 @@
 //! This module provides the in-memory table for more realistic benchmarking.
 
 use arrow::array::{
-    builder::{Int64Builder, StringBuilder},
     ArrayRef, Float32Array, Float64Array, RecordBatch, StringArray,
     StringViewBuilder, UInt64Array,
+    builder::{Int64Builder, StringBuilder, StringDictionaryBuilder},
 };
+use arrow::datatypes::Int32Type;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion::datasource::MemTable;
 use datafusion::error::Result;
@@ -36,6 +37,7 @@ use std::sync::Arc;
 
 /// create an in-memory table given the partition len, array len, and batch size,
 /// and the result table will be of array_len in total, and then partitioned, and batched.
+#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly
 #[allow(dead_code)]
 pub fn create_table_provider(
     partitions_len: usize,
@@ -44,7 +46,7 @@ pub fn create_table_provider(
 ) -> Result<Arc<MemTable>> {
     let schema = Arc::new(create_schema());
     let partitions =
-        create_record_batches(schema.clone(), array_len, partitions_len, batch_size);
+        create_record_batches(&schema, array_len, partitions_len, batch_size);
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     MemTable::try_new(schema, partitions).map(Arc::new)
 }
@@ -55,21 +57,24 @@ pub fn create_schema() -> Schema {
         Field::new("utf8", DataType::Utf8, false),
         Field::new("f32", DataType::Float32, false),
         Field::new("f64", DataType::Float64, true),
-        // This field will contain integers randomly selected from a large
-        // range of values, i.e. [0, u64::MAX], such that there are none (or
-        // very few) repeated values.
-        Field::new("u64_wide", DataType::UInt64, true),
-        // This field will contain integers randomly selected from a narrow
-        // range of values such that there are a few distinct values, but they
-        // are repeated often.
+        // Integers randomly selected from a wide range of values, i.e. [0,
+        // u64::MAX], such that there are ~no repeated values.
+        Field::new("u64_wide", DataType::UInt64, false),
+        // Integers randomly selected from a mid-range of values [0, 1000),
+        // providing ~1000 distinct groups.
+        Field::new("u64_mid", DataType::UInt64, false),
+        // Integers randomly selected from a narrow range of values such that
+        // there are a few distinct values, but they are repeated often.
Field::new("u64_narrow", DataType::UInt64, false), + Field::new( + "dict10", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), ]) } -fn create_data(size: usize, null_density: f64) -> Vec> { - // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = StdRng::seed_from_u64(42); - +fn create_data(rng: &mut StdRng, size: usize, null_density: f64) -> Vec> { (0..size) .map(|_| { if rng.random::() > null_density { @@ -81,57 +86,54 @@ fn create_data(size: usize, null_density: f64) -> Vec> { .collect() } -fn create_integer_data( - rng: &mut StdRng, - size: usize, - value_density: f64, -) -> Vec> { - (0..size) - .map(|_| { - if rng.random::() > value_density { - None - } else { - Some(rng.random::()) - } - }) - .collect() -} - fn create_record_batch( schema: SchemaRef, rng: &mut StdRng, batch_size: usize, - i: usize, + batch_index: usize, ) -> RecordBatch { - // the 4 here is the number of different keys. - // a higher number increase sparseness - let vs = [0, 1, 2, 3]; - let keys: Vec = (0..batch_size) - .map( - // use random numbers to avoid spurious compiler optimizations wrt to branching - |_| format!("hi{:?}", vs.choose(rng)), - ) - .collect(); - let keys: Vec<&str> = keys.iter().map(|e| &**e).collect(); + // Randomly choose from 4 distinct key values; a higher number increases sparseness. + let key_suffixes = [0, 1, 2, 3]; + let keys = StringArray::from_iter_values( + (0..batch_size).map(|_| format!("hi{}", key_suffixes.choose(rng).unwrap())), + ); - let values = create_data(batch_size, 0.5); + let values = create_data(rng, batch_size, 0.5); // Integer values between [0, u64::MAX]. - let integer_values_wide = create_integer_data(rng, batch_size, 9.0); + let integer_values_wide = (0..batch_size) + .map(|_| rng.random::()) + .collect::>(); - // Integer values between [0, 9]. + // Integer values between [0, 1000). + let integer_values_mid = (0..batch_size) + .map(|_| rng.random_range(0..1000)) + .collect::>(); + + // Integer values between [0, 10). 
let integer_values_narrow = (0..batch_size) - .map(|_| rng.random_range(0_u64..10)) + .map(|_| rng.random_range(0..10)) .collect::>(); + let mut dict_builder = StringDictionaryBuilder::::new(); + for _ in 0..batch_size { + if rng.random::() > 0.9 { + dict_builder.append_null(); + } else { + dict_builder.append_value(format!("market_{}", rng.random_range(0..10))); + } + } + RecordBatch::try_new( schema, vec![ - Arc::new(StringArray::from(keys)), - Arc::new(Float32Array::from(vec![i as f32; batch_size])), + Arc::new(keys), + Arc::new(Float32Array::from(vec![batch_index as f32; batch_size])), Arc::new(Float64Array::from(values)), Arc::new(UInt64Array::from(integer_values_wide)), + Arc::new(UInt64Array::from(integer_values_mid)), Arc::new(UInt64Array::from(integer_values_narrow)), + Arc::new(dict_builder.finish()), ], ) .unwrap() @@ -140,19 +142,28 @@ fn create_record_batch( /// Create record batches of `partitions_len` partitions and `batch_size` for each batch, /// with a total number of `array_len` records pub fn create_record_batches( - schema: SchemaRef, + schema: &SchemaRef, array_len: usize, partitions_len: usize, batch_size: usize, ) -> Vec> { let mut rng = StdRng::seed_from_u64(42); - (0..partitions_len) - .map(|_| { - (0..array_len / batch_size / partitions_len) - .map(|i| create_record_batch(schema.clone(), &mut rng, batch_size, i)) - .collect::>() - }) - .collect::>() + let mut partitions = Vec::with_capacity(partitions_len); + let batches_per_partition = array_len / batch_size / partitions_len; + + for _ in 0..partitions_len { + let mut batches = Vec::with_capacity(batches_per_partition); + for batch_index in 0..batches_per_partition { + batches.push(create_record_batch( + schema.clone(), + &mut rng, + batch_size, + batch_index, + )); + } + partitions.push(batches); + } + partitions } /// An enum that wraps either a regular StringBuilder or a GenericByteViewBuilder @@ -182,6 +193,7 @@ impl TraceIdBuilder { /// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition /// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub(crate) fn make_data( partition_cnt: i32, diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 00fa85918347a..5aeade315cc7b 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -15,13 +15,8 @@ // specific language governing permissions and limitations // under the License. -extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - use arrow_schema::{DataType, Field, Schema}; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_expr::col; @@ -45,6 +40,7 @@ fn create_context(field_count: u32) -> datafusion_common::Result, rt: &Runtime) { black_box(rt.block_on(async { let mut data_frame = ctx.table("t").await.unwrap(); diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index d05e8b13b2af3..d389b1b3d6a22 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -15,25 +15,22 @@ // specific language governing permissions and limitations // under the License. 
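The `dict10` column added above holds at most ten distinct strings with roughly 10% nulls, which is what the new `array_agg_struct_query_group_by_mid_groups` bench exercises. A self-contained sketch of the same generation pattern, assuming the arrow and rand 0.9 APIs already used in this patch:

use arrow::array::StringDictionaryBuilder;
use arrow::datatypes::Int32Type;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

fn main() {
    // Fixed seed: every benchmark run sees identical data.
    let mut rng = StdRng::seed_from_u64(42);
    // Int32 keys indexing into a small string dictionary.
    let mut builder = StringDictionaryBuilder::<Int32Type>::new();
    for _ in 0..1024 {
        if rng.random::<f64>() > 0.9 {
            // ~10% of entries are null.
            builder.append_null();
        } else {
            // At most 10 distinct strings, so the dictionary stays tiny
            // while the keys array grows with the batch size.
            builder.append_value(format!("market_{}", rng.random_range(0..10)));
        }
    }
    let array = builder.finish();
    println!("{} keys, {} dictionary values", array.len(), array.values().len());
}

Dictionary encoding keeps the values array at ten entries regardless of batch size, so the group-by benches measure key hashing rather than string allocation.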
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::{create_table_provider, make_data}; use datafusion::execution::context::SessionContext; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; +use datafusion_execution::config::SessionConfig; use parking_lot::Mutex; use std::hint::black_box; use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); black_box(rt.block_on(df.collect()).unwrap()); @@ -124,6 +121,7 @@ async fn distinct_with_limit( Ok(()) } +#[expect(clippy::needless_pass_by_value)] fn run(rt: &Runtime, plan: Arc, ctx: Arc) { black_box(rt.block_on(distinct_with_limit(plan.clone(), ctx.clone()))).unwrap(); } diff --git a/datafusion/core/benches/filter_query_sql.rs b/datafusion/core/benches/filter_query_sql.rs index 16905e0f96605..3b80518d32dcd 100644 --- a/datafusion/core/benches/filter_query_sql.rs +++ b/datafusion/core/benches/filter_query_sql.rs @@ -20,7 +20,7 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 09234546b2dfe..67904197bc257 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::hint::black_box; use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, RecordBatch}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; -use rand::prelude::ThreadRng; use rand::Rng; +use rand::prelude::ThreadRng; use tokio::runtime::Runtime; use datafusion::prelude::SessionContext; @@ -33,11 +34,12 @@ use datafusion_functions_nested::map::map; mod data_utils; fn build_keys(rng: &mut ThreadRng) -> Vec { - let mut keys = vec![]; - for _ in 0..1000 { - keys.push(rng.random_range(0..9999).to_string()); + let mut keys = HashSet::with_capacity(1000); + while keys.len() < 1000 { + let key = rng.random_range(0..9999).to_string(); + keys.insert(key); } - keys + keys.into_iter().collect() } fn build_values(rng: &mut ThreadRng) -> Vec { diff --git a/datafusion/core/benches/math_query_sql.rs b/datafusion/core/benches/math_query_sql.rs index 76824850c114c..f5df56e95a2d8 100644 --- a/datafusion/core/benches/math_query_sql.rs +++ b/datafusion/core/benches/math_query_sql.rs @@ -15,18 +15,13 @@ // specific language governing permissions and limitations // under the License. 
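The `build_keys` rewrite above draws into a `HashSet` until exactly 1000 distinct keys have been produced, presumably because duplicate map keys would either be rejected when the map array is built or would skew the benchmark. The same idea as a free-standing helper (the function name is illustrative):

use rand::Rng;
use std::collections::HashSet;

/// Keep drawing until `n` distinct keys have been seen; a plain `Vec`
/// of 1000 draws from 0..9999 would contain duplicates fairly often.
fn distinct_keys<R: Rng>(rng: &mut R, n: usize) -> Vec<String> {
    let mut keys = HashSet::with_capacity(n);
    while keys.len() < n {
        keys.insert(rng.random_range(0..9999).to_string());
    }
    keys.into_iter().collect()
}

fn main() {
    let mut rng = rand::rng();
    let keys = distinct_keys(&mut rng, 1000);
    assert_eq!(keys.len(), 1000);
}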
-#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -extern crate arrow; -extern crate datafusion; - use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, @@ -36,6 +31,7 @@ use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion::execution::context::SessionContext; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index e2b3810480130..f099137973592 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -23,14 +23,14 @@ use arrow::datatypes::{ SchemaRef, }; use arrow::record_batch::RecordBatch; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::instant::Instant; use futures::stream::StreamExt; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; -use rand::distr::uniform::SampleUniform; use rand::distr::Alphanumeric; +use rand::distr::uniform::SampleUniform; use rand::prelude::*; use rand::rng; use std::fs::File; @@ -45,7 +45,7 @@ const NUM_BATCHES: usize = 2048; /// The number of rows in each record batch to write const WRITE_RECORD_BATCH_SIZE: usize = 1024; /// The number of rows in a row group -const ROW_GROUP_SIZE: usize = 1024 * 1024; +const ROW_GROUP_ROW_COUNT: usize = 1024 * 1024; /// The number of row groups expected const EXPECTED_ROW_GROUPS: usize = 2; @@ -154,7 +154,7 @@ fn generate_file() -> NamedTempFile { let properties = WriterProperties::builder() .set_writer_version(WriterVersion::PARQUET_2_0) - .set_max_row_group_size(ROW_GROUP_SIZE) + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) .build(); let mut writer = diff --git a/datafusion/core/benches/parquet_struct_projection.rs b/datafusion/core/benches/parquet_struct_projection.rs new file mode 100644 index 0000000000000..7d5b220d397f8 --- /dev/null +++ b/datafusion/core/benches/parquet_struct_projection.rs @@ -0,0 +1,496 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for struct leaf-level projection pruning in Parquet. +//! +//! Measures the benefit of reading only the needed leaf columns from a +//! struct column. Three dataset shapes are tested: +//! +//! 1. **Narrow struct** (2 leaves): one 128 KiB UTF-8 field + one INT field +//! 2. 
**Wide struct** (5 leaves): four 128 KiB UTF-8 fields + one INT field +//! 3. **Nested struct** (3 leaves): `STRUCT, extra_string>` +//! +//! In all cases, projecting just the small integer should skip decoding +//! all of the large string leaves, including through nested struct levels. + +use arrow::array::{ArrayRef, Int32Array, StringBuilder, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_common::instant::Instant; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::{WriterProperties, WriterVersion}; +use std::hint::black_box; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; +use tempfile::NamedTempFile; +use tokio::runtime::Runtime; + +const NUM_BATCHES: usize = 2; +const WRITE_RECORD_BATCH_SIZE: usize = 256; +const ROW_GROUP_ROW_COUNT: usize = 256; +const EXPECTED_ROW_GROUPS: usize = 2; +const LARGE_STRING_LEN: usize = 16 * 1024; + +fn narrow_schema() -> SchemaRef { + let struct_fields = Fields::from(vec![ + Field::new("large_string", DataType::Utf8, false), + Field::new("small_int", DataType::Int32, false), + ]); + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(struct_fields), false), + ])) +} + +fn narrow_batch(batch_id: usize) -> RecordBatch { + let schema = narrow_schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + let base_id = (batch_id * len) as i32; + let id_values: Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + + let small_int_array = Arc::new(Int32Array::from(id_values)); + + let large_string: String = "x".repeat(LARGE_STRING_LEN); + let mut string_builder = StringBuilder::new(); + for _ in 0..len { + string_builder.append_value(&large_string); + } + let large_string_array = Arc::new(string_builder.finish()); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("large_string", DataType::Utf8, false)), + large_string_array as ArrayRef, + ), + ( + Arc::new(Field::new("small_int", DataType::Int32, false)), + small_int_array as ArrayRef, + ), + ]); + + RecordBatch::try_new(schema, vec![id_array, Arc::new(struct_array)]).unwrap() +} + +fn wide_schema() -> SchemaRef { + let struct_fields = Fields::from(vec![ + Field::new("str_a", DataType::Utf8, false), + Field::new("str_b", DataType::Utf8, false), + Field::new("str_c", DataType::Utf8, false), + Field::new("str_d", DataType::Utf8, false), + Field::new("small_int", DataType::Int32, false), + ]); + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(struct_fields), false), + ])) +} + +fn wide_batch(batch_id: usize) -> RecordBatch { + let schema = wide_schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + let base_id = (batch_id * len) as i32; + let id_values: Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + let small_int_array = Arc::new(Int32Array::from(id_values)); + + let large_string: String = "x".repeat(LARGE_STRING_LEN); + let mut string_fields: Vec<(Arc, ArrayRef)> = Vec::new(); + for name in &["str_a", "str_b", "str_c", "str_d"] { + let mut sb = StringBuilder::new(); + for _ in 0..len { + sb.append_value(&large_string); + } + string_fields.push(( + Arc::new(Field::new(*name, DataType::Utf8, false)), + Arc::new(sb.finish()) as 
ArrayRef, + )); + } + string_fields.push(( + Arc::new(Field::new("small_int", DataType::Int32, false)), + small_int_array as ArrayRef, + )); + + let struct_array = StructArray::from(string_fields); + RecordBatch::try_new(schema, vec![id_array, Arc::new(struct_array)]).unwrap() +} + +fn generate_file( + schema: SchemaRef, + batch_fn: fn(usize) -> RecordBatch, + prefix: &str, +) -> NamedTempFile { + let now = Instant::now(); + let mut named_file = tempfile::Builder::new() + .prefix(prefix) + .suffix(".parquet") + .tempfile() + .unwrap(); + + println!("Generating parquet file - {}", named_file.path().display()); + + let properties = WriterProperties::builder() + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut named_file, schema, Some(properties)).unwrap(); + + for batch_id in 0..NUM_BATCHES { + let batch = batch_fn(batch_id); + writer.write(&batch).unwrap(); + } + + let metadata = writer.close().unwrap(); + let file_metadata = metadata.file_metadata(); + let expected_rows = WRITE_RECORD_BATCH_SIZE * NUM_BATCHES; + assert_eq!( + file_metadata.num_rows() as usize, + expected_rows, + "Expected {expected_rows} rows but got {}", + file_metadata.num_rows() + ); + assert_eq!( + metadata.row_groups().len(), + EXPECTED_ROW_GROUPS, + "Expected {EXPECTED_ROW_GROUPS} row groups but got {}", + metadata.row_groups().len() + ); + + println!( + "Generated parquet file with {} rows and {} row groups in {:.2}s", + file_metadata.num_rows(), + metadata.row_groups().len(), + now.elapsed().as_secs_f32() + ); + + named_file +} + +fn create_context(rt: &Runtime, file_path: &str, table: &str) -> SessionContext { + let ctx = SessionContext::new(); + rt.block_on(ctx.register_parquet(table, file_path, Default::default())) + .unwrap(); + ctx +} + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let ctx = ctx.clone(); + let sql = sql.to_string(); + let df = rt.block_on(ctx.sql(&sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn narrow_benchmarks(c: &mut Criterion) { + let temp_file = generate_file(narrow_schema(), narrow_batch, "narrow_struct"); + let file_path = temp_file.path().display().to_string(); + assert!(Path::new(&file_path).exists(), "path not found"); + + let rt = Runtime::new().unwrap(); + let ctx = create_context(&rt, &file_path, "t"); + + let mut group = c.benchmark_group("narrow_struct"); + group.sample_size(10); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(2)); + + // baseline: full struct, must decode both leaves + group.bench_function("select_struct", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s FROM t")) + }); + + // pruned: skip large_string, read only small_int + group.bench_function("select_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['small_int'] FROM t")) + }); + + // pruned: skip small_int, read only large_string + group.bench_function("select_large_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['large_string'] FROM t")) + }); + + // no pruning: all columns + group.bench_function("select_all", |b| { + b.iter(|| query(&ctx, &rt, "SELECT * FROM t")) + }); + + // top-level column + pruned struct sub-field + group.bench_function("select_id_and_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT id, s['small_int'] FROM t")) + }); + + // aggregation on pruned sub-field, realistic analytical pattern + group.bench_function("sum_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT 
SUM(s['small_int']) FROM t")) + }); + + group.finish(); + drop(temp_file); +} + +fn wide_benchmarks(c: &mut Criterion) { + let temp_file = generate_file(wide_schema(), wide_batch, "wide_struct"); + let file_path = temp_file.path().display().to_string(); + assert!(Path::new(&file_path).exists(), "path not found"); + + let rt = Runtime::new().unwrap(); + let ctx = create_context(&rt, &file_path, "t"); + + let mut group = c.benchmark_group("wide_struct"); + group.sample_size(10); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(2)); + + // baseline: full struct, must decode all 5 leaves + group.bench_function("select_struct", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s FROM t")) + }); + + // pruned: skip all 4 large string leaves + group.bench_function("select_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['small_int'] FROM t")) + }); + + // pruned: read 1 of 4 string leaves + skip the rest + group.bench_function("select_one_string_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['str_a'] FROM t")) + }); + + // pruned: read 2 of 4 string leaves + group.bench_function("select_two_string_fields", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['str_a'], s['str_b'] FROM t")) + }); + + // no pruning: all columns + group.bench_function("select_all", |b| { + b.iter(|| query(&ctx, &rt, "SELECT * FROM t")) + }); + + // aggregation on pruned sub-field, skips all 4 large leaves + group.bench_function("sum_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT SUM(s['small_int']) FROM t")) + }); + + group.finish(); + drop(temp_file); +} + +fn nested_schema() -> SchemaRef { + let inner_fields = Fields::from(vec![ + Field::new("large_string", DataType::Utf8, false), + Field::new("small_int", DataType::Int32, false), + ]); + let outer_fields = Fields::from(vec![ + Field::new("inner", DataType::Struct(inner_fields), false), + Field::new("extra_string", DataType::Utf8, false), + ]); + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(outer_fields), false), + ])) +} + +fn nested_batch(batch_id: usize) -> RecordBatch { + let schema = nested_schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + let base_id = (batch_id * len) as i32; + let id_values: Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + let small_int_array = Arc::new(Int32Array::from(id_values)); + + let large_string: String = "x".repeat(LARGE_STRING_LEN); + + let mut sb1 = StringBuilder::new(); + let mut sb2 = StringBuilder::new(); + for _ in 0..len { + sb1.append_value(&large_string); + sb2.append_value(&large_string); + } + + let inner_struct = StructArray::from(vec![ + ( + Arc::new(Field::new("large_string", DataType::Utf8, false)), + Arc::new(sb1.finish()) as ArrayRef, + ), + ( + Arc::new(Field::new("small_int", DataType::Int32, false)), + small_int_array as ArrayRef, + ), + ]); + + let inner_fields = Fields::from(vec![ + Field::new("large_string", DataType::Utf8, false), + Field::new("small_int", DataType::Int32, false), + ]); + let outer_struct = StructArray::from(vec![ + ( + Arc::new(Field::new("inner", DataType::Struct(inner_fields), false)), + Arc::new(inner_struct) as ArrayRef, + ), + ( + Arc::new(Field::new("extra_string", DataType::Utf8, false)), + Arc::new(sb2.finish()) as ArrayRef, + ), + ]); + + RecordBatch::try_new(schema, vec![id_array, Arc::new(outer_struct)]).unwrap() +} + +fn nested_benchmarks(c: &mut Criterion) { + let temp_file = 
generate_file(nested_schema(), nested_batch, "nested_struct"); + let file_path = temp_file.path().display().to_string(); + assert!(Path::new(&file_path).exists(), "path not found"); + + let rt = Runtime::new().unwrap(); + let ctx = create_context(&rt, &file_path, "t"); + + let mut group = c.benchmark_group("nested_struct"); + group.sample_size(10); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(2)); + + // baseline: full outer struct, decode all 3 leaves + group.bench_function("select_struct", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s FROM t")) + }); + + // pruned outer: read only inner struct, skip extra_string + group.bench_function("select_inner_struct", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['inner'] FROM t")) + }); + + // pruned both levels: reach through outer + inner, skip both large strings + group.bench_function("select_inner_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['inner']['small_int'] FROM t")) + }); + + // pruned outer only: skip inner struct entirely, read extra_string + group.bench_function("select_extra_string", |b| { + b.iter(|| query(&ctx, &rt, "SELECT s['extra_string'] FROM t")) + }); + + // no pruning: all columns + group.bench_function("select_all", |b| { + b.iter(|| query(&ctx, &rt, "SELECT * FROM t")) + }); + + // aggregation reaching through two levels of nesting + group.bench_function("sum_inner_small_field", |b| { + b.iter(|| query(&ctx, &rt, "SELECT SUM(s['inner']['small_int']) FROM t")) + }); + + group.finish(); + drop(temp_file); +} + +fn flat_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("large_string", DataType::Utf8, false), + Field::new("small_int", DataType::Int32, false), + ])) +} + +fn flat_batch(batch_id: usize) -> RecordBatch { + let schema = flat_schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + let base_id = (batch_id * len) as i32; + let id_values: Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + let small_int_array = Arc::new(Int32Array::from(id_values)); + + let large_string: String = "x".repeat(LARGE_STRING_LEN); + let mut string_builder = StringBuilder::new(); + for _ in 0..len { + string_builder.append_value(&large_string); + } + let large_string_array = Arc::new(string_builder.finish()); + + RecordBatch::try_new( + schema, + vec![id_array, large_string_array as ArrayRef, small_int_array], + ) + .unwrap() +} + +/// Compare selecting a small field from a flat (top-level) schema vs from +/// inside a struct. Both files contain the same logical data — the only +/// difference is whether `small_int` lives at the top level or nested inside +/// a struct column. 
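The comparison below leans on DataFusion's bracket syntax for pulling one leaf out of a struct column. A minimal in-memory sketch of that access path, using a `MemTable` instead of parquet, so it shows only the SQL surface rather than the leaf pruning being measured (assumes default crate features plus tokio's `rt` and `macros` features):

use arrow::array::{ArrayRef, Int32Array, StringArray, StructArray};
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::prelude::SessionContext;
use std::sync::Arc;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Same "narrow" shape as above: one string leaf and one int leaf.
    let large_string: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let small_int: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let s = StructArray::from(vec![
        (
            Arc::new(Field::new("large_string", DataType::Utf8, false)),
            large_string,
        ),
        (
            Arc::new(Field::new("small_int", DataType::Int32, false)),
            small_int,
        ),
    ]);
    let batch = RecordBatch::try_from_iter(vec![("s", Arc::new(s) as ArrayRef)])?;

    let ctx = SessionContext::new();
    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
    ctx.register_table("t", Arc::new(table))?;

    // Bracket syntax selects a single struct leaf; against parquet sources
    // this projection is what lets the reader skip the sibling leaves.
    ctx.sql("SELECT s['small_int'] FROM t").await?.show().await?;
    Ok(())
}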
+fn flat_vs_struct_benchmarks(c: &mut Criterion) { + let flat_file = generate_file(flat_schema(), flat_batch, "flat"); + let flat_path = flat_file.path().display().to_string(); + assert!(Path::new(&flat_path).exists(), "path not found"); + + let struct_file = generate_file(narrow_schema(), narrow_batch, "narrow_struct_cmp"); + let struct_path = struct_file.path().display().to_string(); + assert!(Path::new(&struct_path).exists(), "path not found"); + + let rt = Runtime::new().unwrap(); + let flat_ctx = create_context(&rt, &flat_path, "t"); + let struct_ctx = create_context(&rt, &struct_path, "t"); + + let mut group = c.benchmark_group("flat_vs_struct"); + group.sample_size(10); + group.warm_up_time(Duration::from_secs(1)); + group.measurement_time(Duration::from_secs(2)); + + // small int: top-level vs struct field + group.bench_function("flat_select_small_int", |b| { + b.iter(|| query(&flat_ctx, &rt, "SELECT small_int FROM t")) + }); + group.bench_function("struct_select_small_int", |b| { + b.iter(|| query(&struct_ctx, &rt, "SELECT s['small_int'] FROM t")) + }); + + // large string: top-level vs struct field + group.bench_function("flat_select_large_string", |b| { + b.iter(|| query(&flat_ctx, &rt, "SELECT large_string FROM t")) + }); + group.bench_function("struct_select_large_string", |b| { + b.iter(|| query(&struct_ctx, &rt, "SELECT s['large_string'] FROM t")) + }); + + // aggregation: SUM of small int + group.bench_function("flat_sum_small_int", |b| { + b.iter(|| query(&flat_ctx, &rt, "SELECT SUM(small_int) FROM t")) + }); + group.bench_function("struct_sum_small_int", |b| { + b.iter(|| query(&struct_ctx, &rt, "SELECT SUM(s['small_int']) FROM t")) + }); + + group.finish(); + drop(flat_file); + drop(struct_file); +} + +criterion_group!( + benches, + narrow_benchmarks, + wide_benchmarks, + nested_benchmarks, + flat_vs_struct_benchmarks, +); +criterion_main!(benches); diff --git a/datafusion/core/benches/parquet_struct_query.rs b/datafusion/core/benches/parquet_struct_query.rs new file mode 100644 index 0000000000000..e7e91f0dd0e1e --- /dev/null +++ b/datafusion/core/benches/parquet_struct_query.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Benchmarks of SQL queries on struct columns in parquet data
+
+use arrow::array::{ArrayRef, Int32Array, StringArray, StructArray};
+use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion::prelude::SessionContext;
+use datafusion_common::instant::Instant;
+use parquet::arrow::ArrowWriter;
+use parquet::file::properties::{WriterProperties, WriterVersion};
+use rand::distr::Alphanumeric;
+use rand::prelude::*;
+use rand::rng;
+use std::hint::black_box;
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+use tempfile::NamedTempFile;
+use tokio::runtime::Runtime;
+
+/// The number of batches to write
+const NUM_BATCHES: usize = 128;
+/// The number of rows in each record batch to write
+const WRITE_RECORD_BATCH_SIZE: usize = 4096;
+/// The number of rows in a row group
+const ROW_GROUP_ROW_COUNT: usize = 65536;
+/// The number of row groups expected
+const EXPECTED_ROW_GROUPS: usize = 8;
+/// The range for random string lengths
+const STRING_LENGTH_RANGE: Range<usize> = 50..200;
+
+fn schema() -> SchemaRef {
+    let struct_fields = Fields::from(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("value", DataType::Utf8, false),
+    ]);
+    let struct_type = DataType::Struct(struct_fields);
+
+    Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("s", struct_type, false),
+    ]))
+}
+
+fn generate_strings(len: usize) -> ArrayRef {
+    let mut rng = rng();
+    Arc::new(StringArray::from_iter((0..len).map(|_| {
+        let string_len = rng.random_range(STRING_LENGTH_RANGE.clone());
+        Some(
+            (0..string_len)
+                .map(|_| char::from(rng.sample(Alphanumeric)))
+                .collect::<String>(),
+        )
+    })))
+}
+
+fn generate_batch(batch_id: usize) -> RecordBatch {
+    let schema = schema();
+    let len = WRITE_RECORD_BATCH_SIZE;
+
+    // Generate sequential IDs based on batch_id for uniqueness
+    let base_id = (batch_id * len) as i32;
+    let id_values: Vec<i32> = (0..len).map(|i| base_id + i as i32).collect();
+    let id_array = Arc::new(Int32Array::from(id_values.clone()));
+
+    // Create struct id array (matching top-level id)
+    let struct_id_array = Arc::new(Int32Array::from(id_values));
+
+    // Generate random strings for struct value field
+    let value_array = generate_strings(len);
+
+    // Construct StructArray
+    let struct_array = StructArray::from(vec![
+        (
+            Arc::new(Field::new("id", DataType::Int32, false)),
+            struct_id_array as ArrayRef,
+        ),
+        (
+            Arc::new(Field::new("value", DataType::Utf8, false)),
+            value_array,
+        ),
+    ]);
+
+    RecordBatch::try_new(schema, vec![id_array, Arc::new(struct_array)]).unwrap()
+}
+
+fn generate_file() -> NamedTempFile {
+    let now = Instant::now();
+    let mut named_file = tempfile::Builder::new()
+        .prefix("parquet_struct_query")
+        .suffix(".parquet")
+        .tempfile()
+        .unwrap();
+
+    println!("Generating parquet file - {}", named_file.path().display());
+    let schema = schema();
+
+    let properties = WriterProperties::builder()
+        .set_writer_version(WriterVersion::PARQUET_2_0)
+        .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT))
+        .build();
+
+    let mut writer =
+        ArrowWriter::try_new(&mut named_file, schema, Some(properties)).unwrap();
+
+    for batch_id in 0..NUM_BATCHES {
+        let batch = generate_batch(batch_id);
+        writer.write(&batch).unwrap();
+    }
+
+    let metadata = writer.close().unwrap();
+    let file_metadata = metadata.file_metadata();
+    let expected_rows = WRITE_RECORD_BATCH_SIZE * NUM_BATCHES;
+    assert_eq!(
+        file_metadata.num_rows() as
usize, + expected_rows, + "Expected {} rows but got {}", + expected_rows, + file_metadata.num_rows() + ); + assert_eq!( + metadata.row_groups().len(), + EXPECTED_ROW_GROUPS, + "Expected {} row groups but got {}", + EXPECTED_ROW_GROUPS, + metadata.row_groups().len() + ); + + println!( + "Generated parquet file with {} rows and {} row groups in {} seconds", + file_metadata.num_rows(), + metadata.row_groups().len(), + now.elapsed().as_secs_f32() + ); + + named_file +} + +fn create_context(file_path: &str) -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + rt.block_on(ctx.register_parquet("t", file_path, Default::default())) + .unwrap(); + ctx +} + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let ctx = ctx.clone(); + let sql = sql.to_string(); + let df = rt.block_on(ctx.sql(&sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn criterion_benchmark(c: &mut Criterion) { + let (file_path, temp_file) = match std::env::var("PARQUET_FILE") { + Ok(file) => (file, None), + Err(_) => { + let temp_file = generate_file(); + (temp_file.path().display().to_string(), Some(temp_file)) + } + }; + + assert!(Path::new(&file_path).exists(), "path not found"); + println!("Using parquet file {file_path}"); + + let ctx = create_context(&file_path); + let rt = Runtime::new().unwrap(); + + // Basic struct access + c.bench_function("struct_access", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t")) + }); + + // Filter queries + c.bench_function("filter_struct_field_eq", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where s['id'] = 5")) + }); + + c.bench_function("filter_struct_field_with_select", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t where s['id'] = 5")) + }); + + c.bench_function("filter_top_level_with_struct_select", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t where id = 5")) + }); + + c.bench_function("filter_struct_string_length", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where length(s['value']) > 100")) + }); + + c.bench_function("filter_struct_range", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id from t where s['id'] > 100 and s['id'] < 200", + ) + }) + }); + + // Join queries (limited with WHERE id < 1000 for performance) + c.bench_function("join_struct_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_toplevel", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_toplevel_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.id = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_struct_with_top_level", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] and t1.id = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_and_struct_value", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.s['id'], t2.s['value'] from t t1 join t t2 on t1.id = t2.id where t1.id < 1000" + )) + }); + + // Group by queries + c.bench_function("group_by_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t group by s['id']")) + }); + + c.bench_function("group_by_struct_select_toplevel", |b| { + b.iter(|| query(&ctx, &rt, "select max(id) from t group by s['id']")) + }); + + 
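Every case above funnels through the same `query` helper: `rt.block_on` bridges criterion's synchronous iteration with DataFusion's async execution, and `black_box` keeps the collected batches from being optimized away. Reduced to a skeleton, with the bench name and SQL as placeholders:

use criterion::{Criterion, criterion_group, criterion_main};
use datafusion::prelude::SessionContext;
use std::hint::black_box;
use tokio::runtime::Runtime;

fn criterion_benchmark(c: &mut Criterion) {
    let rt = Runtime::new().unwrap();
    let ctx = SessionContext::new();
    c.bench_function("example_query", |b| {
        b.iter(|| {
            // Plan and execute on the shared runtime; black_box defeats
            // dead-code elimination of the result batches.
            let df = rt.block_on(ctx.sql("SELECT 1")).unwrap();
            black_box(rt.block_on(df.collect()).unwrap());
        })
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);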
c.bench_function("group_by_toplevel_select_struct", |b| { + b.iter(|| query(&ctx, &rt, "select max(s['id']) from t group by id")) + }); + + c.bench_function("group_by_struct_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select s['id'], count(*) from t group by s['id']", + ) + }) + }); + + c.bench_function("group_by_multiple_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'], count(*) from t group by id, s['id']", + ) + }) + }); + + // Additional queries + c.bench_function("order_by_struct_limit", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'] from t order by s['id'] limit 1000", + ) + }) + }); + + c.bench_function("distinct_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select distinct s['id'] from t")) + }); + + // Temporary file must outlive the benchmarks, it is deleted when dropped + drop(temp_file); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index e4838572f60fb..7b66996b05929 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::{BatchSize, Criterion}; -extern crate arrow; -extern crate datafusion; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use std::sync::Arc; @@ -32,7 +28,7 @@ use tokio::runtime::Runtime; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::{ collect, - expressions::{col, PhysicalSortExpr}, + expressions::{PhysicalSortExpr, col}, }; use datafusion::prelude::SessionContext; use datafusion_datasource::memory::MemorySourceConfig; @@ -40,6 +36,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; // Initialize the operator using the provided record batches and the sort key // as inputs. All record batches must have the same schema. +#[expect(clippy::needless_pass_by_value)] fn sort_preserving_merge_operator( session_ctx: Arc, rt: &Runtime, diff --git a/datafusion/core/benches/preserve_file_partitioning.rs b/datafusion/core/benches/preserve_file_partitioning.rs new file mode 100644 index 0000000000000..9b1f59adc6823 --- /dev/null +++ b/datafusion/core/benches/preserve_file_partitioning.rs @@ -0,0 +1,838 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for `preserve_file_partitions` optimization. +//! +//! When enabled, this optimization declares Hive-partitioned tables as +//! `Hash([partition_col])` partitioned, allowing the query optimizer to +//! skip unnecessary repartitioning and sorting operations. 
+//!
+//! When This Optimization Helps
+//! - Window functions: PARTITION BY on the partition column eliminates RepartitionExec and SortExec
+//! - Aggregates with ORDER BY: GROUP BY on the partition column plus ORDER BY eliminates the post-aggregate sort
+//!
+//! When This Optimization Does NOT Help
+//! - GROUP BY non-partition columns: the required Hash distribution doesn't match the declared partitioning
+//! - When the number of distinct file partition groups < the number of available CPUs: the reduced
+//!   parallelism may outweigh the benefit of fewer shuffles
+//!
+//! Usage
+//! - BENCH_SIZE=small|medium|large cargo bench -p datafusion --bench preserve_file_partitioning
+//! - SAVE_PLANS=1 cargo bench ... # Save query plans to files
+
+use arrow::array::{ArrayRef, Float64Array, StringArray, TimestampMillisecondArray};
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use arrow::record_batch::RecordBatch;
+use arrow::util::pretty::pretty_format_batches;
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext, col};
+use datafusion_expr::SortExpr;
+use parquet::arrow::ArrowWriter;
+use parquet::file::properties::WriterProperties;
+use std::fs::{self, File};
+use std::io::Write;
+use std::path::Path;
+use std::sync::Arc;
+use tempfile::TempDir;
+use tokio::runtime::Runtime;
+
+#[derive(Debug, Clone, Copy)]
+struct BenchConfig {
+    fact_partitions: usize,
+    rows_per_partition: usize,
+    target_partitions: usize,
+    measurement_time_secs: u64,
+}
+
+impl BenchConfig {
+    fn small() -> Self {
+        Self {
+            fact_partitions: 10,
+            rows_per_partition: 1_000_000,
+            target_partitions: 10,
+            measurement_time_secs: 15,
+        }
+    }
+
+    fn medium() -> Self {
+        Self {
+            fact_partitions: 30,
+            rows_per_partition: 3_000_000,
+            target_partitions: 30,
+            measurement_time_secs: 30,
+        }
+    }
+
+    fn large() -> Self {
+        Self {
+            fact_partitions: 50,
+            rows_per_partition: 6_000_000,
+            target_partitions: 50,
+            measurement_time_secs: 90,
+        }
+    }
+
+    fn from_env() -> Self {
+        match std::env::var("BENCH_SIZE").as_deref() {
+            Ok("small") | Ok("SMALL") => Self::small(),
+            Ok("medium") | Ok("MEDIUM") => Self::medium(),
+            Ok("large") | Ok("LARGE") => Self::large(),
+            _ => {
+                println!("Using SMALL dataset (set BENCH_SIZE=small|medium|large)");
+                Self::small()
+            }
+        }
+    }
+
+    fn total_rows(&self) -> usize {
+        self.fact_partitions * self.rows_per_partition
+    }
+
+    fn high_cardinality(base: &Self) -> Self {
+        Self {
+            fact_partitions: (base.fact_partitions as f64 * 2.5) as usize,
+            rows_per_partition: base.rows_per_partition / 2,
+            target_partitions: base.target_partitions,
+            measurement_time_secs: base.measurement_time_secs,
+        }
+    }
+}
+
+fn dkey_names(count: usize) -> Vec<String> {
+    (0..count)
+        .map(|i| {
+            if i < 26 {
+                ((b'A' + i as u8) as char).to_string()
+            } else {
+                format!(
+                    "{}{}",
+                    (b'A' + ((i / 26) - 1) as u8) as char,
+                    (b'A' + (i % 26) as u8) as char
+                )
+            }
+        })
+        .collect()
+}
+
+/// Hive-partitioned fact table, sorted by timestamp within each partition.
+fn generate_fact_table(
+    base_dir: &Path,
+    num_partitions: usize,
+    rows_per_partition: usize,
+) {
+    let fact_dir = base_dir.join("fact");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new(
+            "timestamp",
+            DataType::Timestamp(TimeUnit::Millisecond, None),
+            false,
+        ),
+        Field::new("value", DataType::Float64, false),
+    ]));
+
+    let props = WriterProperties::builder()
+        .set_compression(parquet::basic::Compression::SNAPPY)
+        .build();
+
+    let dkeys = dkey_names(num_partitions);
+
+    for dkey in &dkeys {
+        let part_dir = fact_dir.join(format!("f_dkey={dkey}"));
+        fs::create_dir_all(&part_dir).unwrap();
+        let file_path = part_dir.join("data.parquet");
+        let file = File::create(file_path).unwrap();
+
+        let mut writer =
+            ArrowWriter::try_new(file, schema.clone(), Some(props.clone())).unwrap();
+
+        let base_ts = 1672567200000i64; // 2023-01-01T09:00:00
+        let timestamps: Vec<i64> = (0..rows_per_partition)
+            .map(|i| base_ts + (i as i64 * 10000))
+            .collect();
+
+        let values: Vec<f64> = (0..rows_per_partition)
+            .map(|i| 50.0 + (i % 100) as f64 + ((i % 7) as f64 * 10.0))
+            .collect();
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(TimestampMillisecondArray::from(timestamps)) as ArrayRef,
+                Arc::new(Float64Array::from(values)),
+            ],
+        )
+        .unwrap();
+
+        writer.write(&batch).unwrap();
+        writer.close().unwrap();
+    }
+}
+
+/// Single-file dimension table for CollectLeft joins.
+fn generate_dimension_table(base_dir: &Path, num_partitions: usize) {
+    let dim_dir = base_dir.join("dimension");
+    fs::create_dir_all(&dim_dir).unwrap();
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("d_dkey", DataType::Utf8, false),
+        Field::new("env", DataType::Utf8, false),
+        Field::new("service", DataType::Utf8, false),
+        Field::new("host", DataType::Utf8, false),
+    ]));
+
+    let props = WriterProperties::builder()
+        .set_compression(parquet::basic::Compression::SNAPPY)
+        .build();
+
+    let file_path = dim_dir.join("data.parquet");
+    let file = File::create(file_path).unwrap();
+    let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props)).unwrap();
+
+    let dkeys = dkey_names(num_partitions);
+    let envs = ["dev", "prod", "staging", "test"];
+    let services = ["log", "trace", "metric"];
+    let hosts = ["ma", "vim", "nano", "emacs"];
+
+    let d_dkey_vals: Vec<String> = dkeys.clone();
+    let env_vals: Vec<String> = dkeys
+        .iter()
+        .enumerate()
+        .map(|(i, _)| envs[i % envs.len()].to_string())
+        .collect();
+    let service_vals: Vec<String> = dkeys
+        .iter()
+        .enumerate()
+        .map(|(i, _)| services[i % services.len()].to_string())
+        .collect();
+    let host_vals: Vec<String> = dkeys
+        .iter()
+        .enumerate()
+        .map(|(i, _)| hosts[i % hosts.len()].to_string())
+        .collect();
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(d_dkey_vals)) as ArrayRef,
+            Arc::new(StringArray::from(env_vals)),
+            Arc::new(StringArray::from(service_vals)),
+            Arc::new(StringArray::from(host_vals)),
+        ],
+    )
+    .unwrap();
+
+    writer.write(&batch).unwrap();
+    writer.close().unwrap();
+}
+
+struct BenchVariant {
+    name: &'static str,
+    preserve_file_partitions: usize,
+    prefer_existing_sort: bool,
+}
+
+const BENCH_VARIANTS: [BenchVariant; 3] = [
+    BenchVariant {
+        name: "with_optimization",
+        preserve_file_partitions: 1,
+        prefer_existing_sort: false,
+    },
+    BenchVariant {
+        name: "prefer_existing_sort",
+        preserve_file_partitions: 0,
+        prefer_existing_sort: true,
+    },
+    BenchVariant {
+        name: "without_optimization",
+        preserve_file_partitions: 0,
+        prefer_existing_sort: false,
+    },
+];
+
+async fn
save_plans( + output_file: &Path, + fact_path: &str, + dim_path: Option<&str>, + target_partitions: usize, + query: &str, + file_sort_order: Option>>, +) { + let mut file = File::create(output_file).unwrap(); + writeln!(file, "Query: {query}\n").unwrap(); + + for variant in &BENCH_VARIANTS { + let session_config = SessionConfig::new() + .with_target_partitions(target_partitions) + .set_usize( + "datafusion.optimizer.preserve_file_partitions", + variant.preserve_file_partitions, + ) + .set_bool( + "datafusion.optimizer.prefer_existing_sort", + variant.prefer_existing_sort, + ); + let ctx = SessionContext::new_with_config(session_config); + + let mut fact_options = ParquetReadOptions { + table_partition_cols: vec![("f_dkey".to_string(), DataType::Utf8)], + ..Default::default() + }; + if let Some(ref order) = file_sort_order { + fact_options.file_sort_order = order.clone(); + } + ctx.register_parquet("fact", fact_path, fact_options) + .await + .unwrap(); + + if let Some(dim) = dim_path { + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + let dim_options = ParquetReadOptions { + schema: Some(&dim_schema), + ..Default::default() + }; + ctx.register_parquet("dimension", dim, dim_options) + .await + .unwrap(); + } + + let df = ctx.sql(query).await.unwrap(); + let plan = df.explain(false, false).unwrap().collect().await.unwrap(); + writeln!(file, "=== {} ===", variant.name).unwrap(); + writeln!(file, "{}\n", pretty_format_batches(&plan).unwrap()).unwrap(); + } +} + +#[expect(clippy::too_many_arguments)] +fn run_benchmark( + c: &mut Criterion, + rt: &Runtime, + name: &str, + fact_path: &str, + dim_path: Option<&str>, + target_partitions: usize, + query: &str, + file_sort_order: &Option>>, +) { + if std::env::var("SAVE_PLANS").is_ok() { + let output_path = format!("{name}_plans.txt"); + rt.block_on(save_plans( + Path::new(&output_path), + fact_path, + dim_path, + target_partitions, + query, + file_sort_order.clone(), + )); + println!("Plans saved to {output_path}"); + } + + let mut group = c.benchmark_group(name); + + for variant in &BENCH_VARIANTS { + let fact_path_owned = fact_path.to_string(); + let dim_path_owned = dim_path.map(|s| s.to_string()); + let sort_order = file_sort_order.clone(); + let query_owned = query.to_string(); + let preserve_file_partitions = variant.preserve_file_partitions; + let prefer_existing_sort = variant.prefer_existing_sort; + + group.bench_function(variant.name, |b| { + b.to_async(rt).iter(|| { + let fact_path = fact_path_owned.clone(); + let dim_path = dim_path_owned.clone(); + let sort_order = sort_order.clone(); + let query = query_owned.clone(); + async move { + let session_config = SessionConfig::new() + .with_target_partitions(target_partitions) + .set_usize( + "datafusion.optimizer.preserve_file_partitions", + preserve_file_partitions, + ) + .set_bool( + "datafusion.optimizer.prefer_existing_sort", + prefer_existing_sort, + ); + let ctx = SessionContext::new_with_config(session_config); + + let mut fact_options = ParquetReadOptions { + table_partition_cols: vec![( + "f_dkey".to_string(), + DataType::Utf8, + )], + ..Default::default() + }; + if let Some(ref order) = sort_order { + fact_options.file_sort_order = order.clone(); + } + ctx.register_parquet("fact", &fact_path, fact_options) + .await + .unwrap(); + + if let Some(ref dim) = dim_path { + let dim_schema = 
Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + let dim_options = ParquetReadOptions { + schema: Some(&dim_schema), + ..Default::default() + }; + ctx.register_parquet("dimension", dim, dim_options) + .await + .unwrap(); + } + + let df = ctx.sql(&query).await.unwrap(); + df.collect().await.unwrap() + } + }) + }); + } + + group.finish(); +} + +/// Aggregate on high-cardinality partitions which eliminates repartition and sort. +/// +/// Query: SELECT f_dkey, COUNT(*), SUM(value) FROM fact GROUP BY f_dkey ORDER BY f_dkey +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ Sort Preserved │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ No repartitioning needed │ +/// │ │ (SinglePartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=Hash([f_dkey]) │ +/// │ │ file_groups={N groups} │ │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ Sort Preserved │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ Hash([f_dkey], N) │ Uses k-way merge to maintain sort, has overhead │ +/// │ │ preserve_order=true │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ FinalPartitioned │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ │ [f_dkey ASC] │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ │ Hash([f_dkey], N) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ 
┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_bench( + c: &mut Criterion, + rt: &Runtime, + hc_fact_path: &str, + target_partitions: usize, +) { + let query = "SELECT f_dkey, COUNT(*) as cnt, SUM(value) as total \ + FROM fact \ + GROUP BY f_dkey \ + ORDER BY f_dkey"; + + let file_sort_order = vec![vec![col("f_dkey").sort(true, false)]]; + + run_benchmark( + c, + rt, + "preserve_order", + hc_fact_path, + None, + target_partitions, + query, + &Some(file_sort_order), + ); +} + +/// Join and aggregate on partition column which demonstrates propagation through join. +/// +/// Query: SELECT f.f_dkey, MAX(d.env), ... FROM fact f JOIN dimension d ON f.f_dkey = d.d_dkey +/// WHERE d.service = 'log' GROUP BY f.f_dkey ORDER BY f.f_dkey +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ Hash partitioning propagates through join │ +/// │ │ (SinglePartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ Hash partitioning preserved on probe side │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=Hash([f_dkey]), output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact, N groups) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ preserve_order=true │ Uses k-way merge to maintain sort, has overhead │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// 
┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_join_bench( + c: &mut Criterion, + rt: &Runtime, + hc_fact_path: &str, + dim_path: &str, + target_partitions: usize, +) { + let query = "SELECT f.f_dkey, MAX(d.env), MAX(d.service), COUNT(*), SUM(f.value) \ + FROM fact f \ + INNER JOIN dimension d ON f.f_dkey = d.d_dkey \ + WHERE d.service = 'log' \ + GROUP BY f.f_dkey \ + ORDER BY f.f_dkey"; + + let file_sort_order = vec![vec![col("f_dkey").sort(true, false)]]; + + run_benchmark( + c, + rt, + "preserve_order_join", + hc_fact_path, + Some(dim_path), + target_partitions, + query, + &Some(file_sort_order), + ); +} + +/// Window function with LIMIT which demonstrates partition and sort elimination. 
+///
+/// Query: SELECT f_dkey, timestamp, value,
+/// ROW_NUMBER() OVER (PARTITION BY f_dkey ORDER BY timestamp) as rn
+/// FROM fact LIMIT 1000
+///
+/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐
+/// │ with_optimization │
+/// │ (preserve_file_partitions=enabled) │
+/// │ │
+/// │ ┌───────────────────────────┐ │
+/// │ │ GlobalLimitExec │ │
+/// │ │ (LIMIT 1000) │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ BoundedWindowAggExec │ No repartition needed │
+/// │ │ PARTITION BY f_dkey │ │
+/// │ │ ORDER BY timestamp │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ DataSourceExec │ partitioning=Hash([f_dkey]), output_ordering=[f_dkey, timestamp] │
+/// │ │ file_groups={N groups} │ │
+/// │ └───────────────────────────┘ │
+/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+///
+/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐
+/// │ prefer_existing_sort │
+/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │
+/// │ │
+/// │ ┌───────────────────────────┐ │
+/// │ │ GlobalLimitExec │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ BoundedWindowAggExec │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ RepartitionExec │ Hash shuffle with order preservation │
+/// │ │ Hash([f_dkey], N) │ Uses k-way merge to maintain sort, has overhead │
+/// │ │ preserve_order=true │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey, timestamp] │
+/// │ └───────────────────────────┘ │
+/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+///
+/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐
+/// │ without_optimization │
+/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │
+/// │ │
+/// │ ┌───────────────────────────┐ │
+/// │ │ GlobalLimitExec │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ BoundedWindowAggExec │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ SortExec │ Must sort after shuffle │
+/// │ │ [f_dkey, timestamp] │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ RepartitionExec │ Hash shuffle destroys ordering │
+/// │ │ Hash([f_dkey], N) │ │
+/// │ └─────────────┬─────────────┘ │
+/// │ │ │
+/// │ ┌─────────────▼─────────────┐ │
+/// │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey, timestamp] │
+/// │ └───────────────────────────┘ │
+/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+fn preserve_order_window_bench(
+    c: &mut Criterion,
+    rt: &Runtime,
+    fact_path: &str,
+    target_partitions: usize,
+) {
+    let query = "SELECT f_dkey, timestamp, value, \
+        ROW_NUMBER() OVER (PARTITION BY f_dkey ORDER BY timestamp) as rn \
+        FROM fact \
+        LIMIT 1000";
+
+    let file_sort_order = vec![vec![
+        col("f_dkey").sort(true, false),
+        col("timestamp").sort(true, false),
+    ]];
+
+    
run_benchmark( + c, + rt, + "preserve_order_window", + fact_path, + None, + target_partitions, + query, + &Some(file_sort_order), + ); +} + +fn benchmark_main(c: &mut Criterion) { + let config = BenchConfig::from_env(); + let hc_config = BenchConfig::high_cardinality(&config); + + println!("\n=== Preserve File Partitioning Benchmark ==="); + println!( + "Normal config: {} partitions × {} rows = {} total rows", + config.fact_partitions, + config.rows_per_partition, + config.total_rows() + ); + println!( + "High-cardinality config: {} partitions × {} rows = {} total rows", + hc_config.fact_partitions, + hc_config.rows_per_partition, + hc_config.total_rows() + ); + println!("Target partitions: {}\n", config.target_partitions); + + let tmp_dir = TempDir::new().unwrap(); + println!("Generating data..."); + + // High-cardinality fact table + generate_fact_table( + tmp_dir.path(), + hc_config.fact_partitions, + hc_config.rows_per_partition, + ); + let hc_fact_dir = tmp_dir.path().join("fact_hc"); + fs::rename(tmp_dir.path().join("fact"), &hc_fact_dir).unwrap(); + let hc_fact_path = hc_fact_dir.to_str().unwrap().to_string(); + + // Normal fact table + generate_fact_table( + tmp_dir.path(), + config.fact_partitions, + config.rows_per_partition, + ); + let fact_path = tmp_dir.path().join("fact").to_str().unwrap().to_string(); + + // Dimension table (for join) + generate_dimension_table(tmp_dir.path(), hc_config.fact_partitions); + let dim_path = tmp_dir + .path() + .join("dimension") + .to_str() + .unwrap() + .to_string(); + + println!("Done.\n"); + + let rt = Runtime::new().unwrap(); + + preserve_order_bench(c, &rt, &hc_fact_path, hc_config.target_partitions); + preserve_order_join_bench( + c, + &rt, + &hc_fact_path, + &dim_path, + hc_config.target_partitions, + ); + preserve_order_window_bench(c, &rt, &fact_path, config.target_partitions); +} + +criterion_group! 
{ + name = benches; + config = { + let config = BenchConfig::from_env(); + Criterion::default() + .measurement_time(std::time::Duration::from_secs(config.measurement_time_secs)) + .sample_size(10) + }; + targets = benchmark_main +} +criterion_main!(benches); diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs index 139fb12c30947..d41085907dbc8 100644 --- a/datafusion/core/benches/push_down_filter.rs +++ b/datafusion/core/benches/push_down_filter.rs @@ -18,16 +18,16 @@ use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use bytes::{BufMut, BytesMut}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::config::ConfigOptions; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::ObjectStore; +use object_store::{ObjectStore, ObjectStoreExt}; use parquet::arrow::ArrowWriter; use std::sync::Arc; diff --git a/datafusion/core/benches/range_and_generate_series.rs b/datafusion/core/benches/range_and_generate_series.rs new file mode 100644 index 0000000000000..10d560df0813e --- /dev/null +++ b/datafusion/core/benches/range_and_generate_series.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+mod data_utils;
+
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion::execution::context::SessionContext;
+use parking_lot::Mutex;
+use std::hint::black_box;
+use std::sync::Arc;
+use tokio::runtime::Runtime;
+
+#[expect(clippy::needless_pass_by_value)]
+fn query(ctx: Arc<Mutex<SessionContext>>, rt: &Runtime, sql: &str) {
+    let df = rt.block_on(ctx.lock().sql(sql)).unwrap();
+    black_box(rt.block_on(df.collect()).unwrap());
+}
+
+fn create_context() -> Arc<Mutex<SessionContext>> {
+    let ctx = SessionContext::new();
+    Arc::new(Mutex::new(ctx))
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let ctx = create_context();
+    let rt = Runtime::new().unwrap();
+
+    c.bench_function("range(1000000)", |b| {
+        b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(1000000)"))
+    });
+
+    c.bench_function("generate_series(1000000)", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                &rt,
+                "SELECT value from generate_series(1000000)",
+            )
+        })
+    });
+
+    c.bench_function("range(0, 1000000, 5)", |b| {
+        b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(0, 1000000, 5)"))
+    });
+
+    c.bench_function("generate_series(0, 1000000, 5)", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                &rt,
+                "SELECT value from generate_series(0, 1000000, 5)",
+            )
+        })
+    });
+
+    c.bench_function("range(1000000, 0, -5)", |b| {
+        b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(1000000, 0, -5)"))
+    });
+
+    c.bench_function("generate_series(1000000, 0, -5)", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                &rt,
+                "SELECT value from generate_series(1000000, 0, -5)",
+            )
+        })
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/core/benches/reset_plan_states.rs b/datafusion/core/benches/reset_plan_states.rs
new file mode 100644
index 0000000000000..5afae7f43242d
--- /dev/null
+++ b/datafusion/core/benches/reset_plan_states.rs
@@ -0,0 +1,200 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::{Arc, LazyLock};
+
+use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion::prelude::SessionContext;
+use datafusion_catalog::MemTable;
+use datafusion_physical_plan::ExecutionPlan;
+use datafusion_physical_plan::displayable;
+use datafusion_physical_plan::execution_plan::reset_plan_states;
+use tokio::runtime::Runtime;
+
+const NUM_FIELDS: usize = 1000;
+const PREDICATE_LEN: usize = 50;
+
+static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
+    Arc::new(Schema::new(
+        (0..NUM_FIELDS)
+            .map(|i| Arc::new(Field::new(format!("x_{i}"), DataType::Int64, false)))
+            .collect::<Fields>(),
+    ))
+});
+
+fn col_name(i: usize) -> String {
+    format!("x_{i}")
+}
+
+fn aggr_name(i: usize) -> String {
+    format!("aggr_{i}")
+}
+
+fn physical_plan(
+    ctx: &SessionContext,
+    rt: &Runtime,
+    sql: &str,
+) -> Arc<dyn ExecutionPlan> {
+    rt.block_on(async {
+        ctx.sql(sql)
+            .await
+            .unwrap()
+            .create_physical_plan()
+            .await
+            .unwrap()
+    })
+}
+
+fn predicate(col_name: impl Fn(usize) -> String, len: usize) -> String {
+    let mut predicate = String::new();
+    for i in 0..len {
+        if i > 0 {
+            predicate.push_str(" AND ");
+        }
+        predicate.push_str(&col_name(i));
+        predicate.push_str(" = ");
+        predicate.push_str(&i.to_string());
+    }
+    predicate
+}
+
+/// Returns a typical plan for the query like:
+///
+/// ```sql
+/// SELECT aggr1(col1) as aggr1, aggr2(col2) as aggr2 FROM t
+/// WHERE p1
+/// HAVING p2
+/// ```
+///
+/// Where `p1` and `p2` are some long predicates.
+///
+fn query1() -> String {
+    let mut query = String::new();
+    query.push_str("SELECT ");
+    for i in 0..NUM_FIELDS {
+        if i > 0 {
+            query.push_str(", ");
+        }
+        query.push_str("AVG(");
+        query.push_str(&col_name(i));
+        query.push_str(") AS ");
+        query.push_str(&aggr_name(i));
+    }
+    query.push_str(" FROM t WHERE ");
+    query.push_str(&predicate(col_name, PREDICATE_LEN));
+    query.push_str(" HAVING ");
+    query.push_str(&predicate(aggr_name, PREDICATE_LEN));
+    query
+}
+
+/// Returns a typical plan for the query like:
+///
+/// ```sql
+/// SELECT projection FROM t JOIN v ON t.a = v.a
+/// WHERE p1
+/// ```
+///
+fn query2() -> String {
+    let mut query = String::new();
+    query.push_str("SELECT ");
+    for i in (0..NUM_FIELDS).step_by(2) {
+        if i > 0 {
+            query.push_str(", ");
+        }
+        if (i / 2) % 2 == 0 {
+            query.push_str(&format!("t.{}", col_name(i)));
+        } else {
+            query.push_str(&format!("v.{}", col_name(i)));
+        }
+    }
+    query.push_str(" FROM t JOIN v ON t.x_0 = v.x_0 WHERE ");
+
+    fn qualified_name(i: usize) -> String {
+        format!("t.{}", col_name(i))
+    }
+
+    query.push_str(&predicate(qualified_name, PREDICATE_LEN));
+    query
+}
+
+/// Returns a typical plan for the query like:
+///
+/// ```sql
+/// SELECT projection FROM t
+/// WHERE p
+/// ```
+///
+fn query3() -> String {
+    let mut query = String::new();
+    query.push_str("SELECT ");
+
+    // Create non-trivial projection.
+    for i in 0..NUM_FIELDS / 2 {
+        if i > 0 {
+            query.push_str(", ");
+        }
+        query.push_str(&col_name(i * 2));
+        query.push_str(" + ");
+        query.push_str(&col_name(i * 2 + 1));
+    }
+
+    query.push_str(" FROM t WHERE ");
+    query.push_str(&predicate(col_name, PREDICATE_LEN));
+    query
+}
+
+fn run_reset_states(b: &mut criterion::Bencher, plan: &Arc<dyn ExecutionPlan>) {
+    b.iter(|| std::hint::black_box(reset_plan_states(Arc::clone(plan)).unwrap()));
+}
+
+/// Measures the overhead of the steps required to make an independent instance
+/// of an execution plan so that it can be re-executed without repeating the
+/// planning stage.
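+///
+/// A minimal sketch of the pattern being measured (using this file's own
+/// helpers; error handling elided):
+///
+/// ```ignore
+/// let plan = physical_plan(&ctx, &rt, &query1());
+/// // Each call yields an independent plan instance with fresh execution state:
+/// let fresh = reset_plan_states(Arc::clone(&plan)).unwrap();
+/// ```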
+fn bench_reset_plan_states(c: &mut Criterion) {
+    env_logger::init();
+
+    let rt = Runtime::new().unwrap();
+    let ctx = SessionContext::new();
+    ctx.register_table(
+        "t",
+        Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()),
+    )
+    .unwrap();
+
+    ctx.register_table(
+        "v",
+        Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()),
+    )
+    .unwrap();
+
+    macro_rules! bench_query {
+        ($query_producer: expr) => {{
+            let sql = $query_producer();
+            let plan = physical_plan(&ctx, &rt, &sql);
+            log::debug!("plan:\n{}", displayable(plan.as_ref()).indent(true));
+            move |b| run_reset_states(b, &plan)
+        }};
+    }
+
+    c.bench_function("query1", bench_query!(query1));
+    c.bench_function("query2", bench_query!(query2));
+    c.bench_function("query3", bench_query!(query3));
+}
+
+criterion_group!(benches, bench_reset_plan_states);
+criterion_main!(benches);
diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs
index 540f7212e96e9..d06ed3f28b743 100644
--- a/datafusion/core/benches/scalar.rs
+++ b/datafusion/core/benches/scalar.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{Criterion, criterion_group, criterion_main};
 use datafusion::scalar::ScalarValue;
 
 fn criterion_benchmark(c: &mut Criterion) {
diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs
index 276151e253f7e..7544f7ae26d43 100644
--- a/datafusion/core/benches/sort.rs
+++ b/datafusion/core/benches/sort.rs
@@ -78,18 +78,18 @@ use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::{
     execution::context::TaskContext,
     physical_plan::{
+        ExecutionPlan, ExecutionPlanProperties,
         coalesce_partitions::CoalescePartitionsExec,
-        sorts::sort_preserving_merge::SortPreservingMergeExec, ExecutionPlan,
-        ExecutionPlanProperties,
+        sorts::sort_preserving_merge::SortPreservingMergeExec,
     },
     prelude::SessionContext,
 };
 use datafusion_datasource::memory::MemorySourceConfig;
-use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
+use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 
 /// Benchmarks for SortPreservingMerge stream
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{Criterion, criterion_group, criterion_main};
 use futures::StreamExt;
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
@@ -102,61 +102,104 @@ const NUM_STREAMS: usize = 8;
 
 /// The size of each batch within each stream
 const BATCH_SIZE: usize = 1024;
 
-/// Total number of input rows to generate
-const INPUT_SIZE: u64 = 100000;
+/// Input sizes to benchmark. The small size (100K) exercises the
+/// in-memory concat-and-sort path; the large size (1M) exercises
+/// the sort-then-merge path with high fan-in.
+const INPUT_SIZES: &[(u64, &str)] = &[(100_000, "100k"), (1_000_000, "1M")];
 
 type PartitionedBatches = Vec<Vec<RecordBatch>>;
+type StreamGenerator = Box<dyn Fn(bool) -> PartitionedBatches>;
 
 fn criterion_benchmark(c: &mut Criterion) {
-    let cases: Vec<(&str, &dyn Fn(bool) -> PartitionedBatches)> = vec![
-        ("i64", &i64_streams),
-        ("f64", &f64_streams),
-        ("utf8 low cardinality", &utf8_low_cardinality_streams),
-        ("utf8 high cardinality", &utf8_high_cardinality_streams),
-        (
-            "utf8 view low cardinality",
-            &utf8_view_low_cardinality_streams,
-        ),
-        (
-            "utf8 view high cardinality",
-            &utf8_view_high_cardinality_streams,
-        ),
-        ("utf8 tuple", &utf8_tuple_streams),
-        ("utf8 view tuple", &utf8_view_tuple_streams),
-        ("utf8 dictionary", &dictionary_streams),
-        ("utf8 dictionary tuple", &dictionary_tuple_streams),
-        ("mixed dictionary tuple", &mixed_dictionary_tuple_streams),
-        ("mixed tuple", &mixed_tuple_streams),
-        (
-            "mixed tuple with utf8 view",
-            &mixed_tuple_with_utf8_view_streams,
-        ),
-    ];
-
-    for (name, f) in cases {
-        c.bench_function(&format!("merge sorted {name}"), |b| {
-            let data = f(true);
-            let case = BenchCase::merge_sorted(&data);
-            b.iter(move || case.run())
-        });
-
-        c.bench_function(&format!("sort merge {name}"), |b| {
-            let data = f(false);
-            let case = BenchCase::sort_merge(&data);
-            b.iter(move || case.run())
-        });
-
-        c.bench_function(&format!("sort {name}"), |b| {
-            let data = f(false);
-            let case = BenchCase::sort(&data);
-            b.iter(move || case.run())
-        });
-
-        c.bench_function(&format!("sort partitioned {name}"), |b| {
-            let data = f(false);
-            let case = BenchCase::sort_partitioned(&data);
-            b.iter(move || case.run())
-        });
+    for &(input_size, size_label) in INPUT_SIZES {
+        let cases: Vec<(&str, StreamGenerator)> = vec![
+            (
+                "i64",
+                Box::new(move |sorted| i64_streams(sorted, input_size)),
+            ),
+            (
+                "f64",
+                Box::new(move |sorted| f64_streams(sorted, input_size)),
+            ),
+            (
+                "utf8 low cardinality",
+                Box::new(move |sorted| utf8_low_cardinality_streams(sorted, input_size)),
+            ),
+            (
+                "utf8 high cardinality",
+                Box::new(move |sorted| utf8_high_cardinality_streams(sorted, input_size)),
+            ),
+            (
+                "utf8 view low cardinality",
+                Box::new(move |sorted| {
+                    utf8_view_low_cardinality_streams(sorted, input_size)
+                }),
+            ),
+            (
+                "utf8 view high cardinality",
+                Box::new(move |sorted| {
+                    utf8_view_high_cardinality_streams(sorted, input_size)
+                }),
+            ),
+            (
+                "utf8 tuple",
+                Box::new(move |sorted| utf8_tuple_streams(sorted, input_size)),
+            ),
+            (
+                "utf8 view tuple",
+                Box::new(move |sorted| utf8_view_tuple_streams(sorted, input_size)),
+            ),
+            (
+                "utf8 dictionary",
+                Box::new(move |sorted| dictionary_streams(sorted, input_size)),
+            ),
+            (
+                "utf8 dictionary tuple",
+                Box::new(move |sorted| dictionary_tuple_streams(sorted, input_size)),
+            ),
+            (
+                "mixed dictionary tuple",
+                Box::new(move |sorted| {
+                    mixed_dictionary_tuple_streams(sorted, input_size)
+                }),
+            ),
+            (
+                "mixed tuple",
+                Box::new(move |sorted| mixed_tuple_streams(sorted, input_size)),
+            ),
+            (
+                "mixed tuple with utf8 view",
+                Box::new(move |sorted| {
+                    mixed_tuple_with_utf8_view_streams(sorted, input_size)
+                }),
+            ),
+        ];
+
+        for (name, f) in &cases {
+            c.bench_function(&format!("merge sorted {name} {size_label}"), |b| {
+                let data = f(true);
+                let case = BenchCase::merge_sorted(&data);
+                b.iter(move || case.run())
+            });
+
+            c.bench_function(&format!("sort merge {name} {size_label}"), |b| {
+                let data = f(false);
+                let case = BenchCase::sort_merge(&data);
+                b.iter(move || case.run())
+            });
+
+            c.bench_function(&format!("sort {name} 
{size_label}"), |b| { + let data = f(false); + let case = BenchCase::sort(&data); + b.iter(move || case.run()) + }); + + c.bench_function(&format!("sort partitioned {name} {size_label}"), |b| { + let data = f(false); + let case = BenchCase::sort_partitioned(&data); + b.iter(move || case.run()) + }); + } } } @@ -279,8 +322,8 @@ fn make_sort_exprs(schema: &Schema) -> LexOrdering { } /// Create streams of int64 (where approximately 1/3 values is repeated) -fn i64_streams(sorted: bool) -> PartitionedBatches { - let mut values = DataGenerator::new().i64_values(); +fn i64_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut values = DataGenerator::new(input_size).i64_values(); if sorted { values.sort_unstable(); } @@ -293,8 +336,8 @@ fn i64_streams(sorted: bool) -> PartitionedBatches { /// Create streams of f64 (where approximately 1/3 values are repeated) /// with the same distribution as i64_streams -fn f64_streams(sorted: bool) -> PartitionedBatches { - let mut values = DataGenerator::new().f64_values(); +fn f64_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut values = DataGenerator::new(input_size).f64_values(); if sorted { values.sort_unstable_by(|a, b| a.total_cmp(b)); } @@ -306,8 +349,8 @@ fn f64_streams(sorted: bool) -> PartitionedBatches { } /// Create streams of random low cardinality utf8 values -fn utf8_low_cardinality_streams(sorted: bool) -> PartitionedBatches { - let mut values = DataGenerator::new().utf8_low_cardinality_values(); +fn utf8_low_cardinality_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut values = DataGenerator::new(input_size).utf8_low_cardinality_values(); if sorted { values.sort_unstable(); } @@ -318,8 +361,11 @@ fn utf8_low_cardinality_streams(sorted: bool) -> PartitionedBatches { } /// Create streams of random low cardinality utf8_view values -fn utf8_view_low_cardinality_streams(sorted: bool) -> PartitionedBatches { - let mut values = DataGenerator::new().utf8_low_cardinality_values(); +fn utf8_view_low_cardinality_streams( + sorted: bool, + input_size: u64, +) -> PartitionedBatches { + let mut values = DataGenerator::new(input_size).utf8_low_cardinality_values(); if sorted { values.sort_unstable(); } @@ -330,8 +376,11 @@ fn utf8_view_low_cardinality_streams(sorted: bool) -> PartitionedBatches { } /// Create streams of high cardinality (~ no duplicates) utf8_view values -fn utf8_view_high_cardinality_streams(sorted: bool) -> PartitionedBatches { - let mut values = DataGenerator::new().utf8_high_cardinality_values(); +fn utf8_view_high_cardinality_streams( + sorted: bool, + input_size: u64, +) -> PartitionedBatches { + let mut values = DataGenerator::new(input_size).utf8_high_cardinality_values(); if sorted { values.sort_unstable(); } @@ -342,8 +391,8 @@ fn utf8_view_high_cardinality_streams(sorted: bool) -> PartitionedBatches { } /// Create streams of high cardinality (~ no duplicates) utf8 values -fn utf8_high_cardinality_streams(sorted: bool) -> PartitionedBatches { - let mut values = DataGenerator::new().utf8_high_cardinality_values(); +fn utf8_high_cardinality_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut values = DataGenerator::new(input_size).utf8_high_cardinality_values(); if sorted { values.sort_unstable(); } @@ -354,15 +403,15 @@ fn utf8_high_cardinality_streams(sorted: bool) -> PartitionedBatches { } /// Create a batch of (utf8_low, utf8_low, utf8_high) -fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); +fn 
utf8_tuple_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut data_gen = DataGenerator::new(input_size); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_high_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -387,15 +436,15 @@ fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { } /// Create a batch of (utf8_view_low, utf8_view_low, utf8_view_high) -fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); +fn utf8_view_tuple_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut data_gen = DataGenerator::new(input_size); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_high_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -420,16 +469,16 @@ fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches { } /// Create a batch of (f64, utf8_low, utf8_low, i64) -fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); +fn mixed_tuple_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut data_gen = DataGenerator::new(input_size); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { @@ -458,16 +507,19 @@ fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { } /// Create a batch of (f64, utf8_view_low, utf8_view_low, i64) -fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); +fn mixed_tuple_with_utf8_view_streams( + sorted: bool, + input_size: u64, +) -> PartitionedBatches { + let mut data_gen = DataGenerator::new(input_size); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { @@ -496,9 +548,9 @@ fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> PartitionedBatches { } /// Create a batch of (utf8_dict) -fn dictionary_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut values = gen.utf8_low_cardinality_values(); +fn dictionary_streams(sorted: bool, input_size: u64) -> PartitionedBatches { + let mut data_gen = DataGenerator::new(input_size); + let mut values = data_gen.utf8_low_cardinality_values(); if sorted { values.sort_unstable(); } @@ -511,13 +563,13 @@ fn dictionary_streams(sorted: bool) -> PartitionedBatches { } /// Create a batch of (utf8_dict, utf8_dict, utf8_dict) -fn 
dictionary_tuple_streams(sorted: bool) -> PartitionedBatches {
-    let mut gen = DataGenerator::new();
-    let mut tuples: Vec<_> = gen
+fn dictionary_tuple_streams(sorted: bool, input_size: u64) -> PartitionedBatches {
+    let mut data_gen = DataGenerator::new(input_size);
+    let mut tuples: Vec<_> = data_gen
         .utf8_low_cardinality_values()
         .into_iter()
-        .zip(gen.utf8_low_cardinality_values())
-        .zip(gen.utf8_low_cardinality_values())
+        .zip(data_gen.utf8_low_cardinality_values())
+        .zip(data_gen.utf8_low_cardinality_values())
         .collect();
 
     if sorted {
@@ -542,14 +594,14 @@
 }
 
 /// Create a batch of (utf8_dict, utf8_dict, utf8_dict, i64)
-fn mixed_dictionary_tuple_streams(sorted: bool) -> PartitionedBatches {
-    let mut gen = DataGenerator::new();
-    let mut tuples: Vec<_> = gen
+fn mixed_dictionary_tuple_streams(sorted: bool, input_size: u64) -> PartitionedBatches {
+    let mut data_gen = DataGenerator::new(input_size);
+    let mut tuples: Vec<_> = data_gen
         .utf8_low_cardinality_values()
         .into_iter()
-        .zip(gen.utf8_low_cardinality_values())
-        .zip(gen.utf8_low_cardinality_values())
-        .zip(gen.i64_values())
+        .zip(data_gen.utf8_low_cardinality_values())
+        .zip(data_gen.utf8_low_cardinality_values())
+        .zip(data_gen.i64_values())
         .collect();
 
     if sorted {
@@ -579,19 +631,21 @@
 /// Encapsulates creating data for this test
 struct DataGenerator {
     rng: StdRng,
+    input_size: u64,
 }
 
 impl DataGenerator {
-    fn new() -> Self {
+    fn new(input_size: u64) -> Self {
         Self {
            rng: StdRng::seed_from_u64(42),
+           input_size,
         }
     }
 
     /// Create an array of i64 sorted values (where approximately 1/3 values is repeated)
     fn i64_values(&mut self) -> Vec<i64> {
-        let mut vec: Vec<_> = (0..INPUT_SIZE)
-            .map(|_| self.rng.random_range(0..INPUT_SIZE as i64))
+        let mut vec: Vec<_> = (0..self.input_size)
+            .map(|_| self.rng.random_range(0..self.input_size as i64))
            .collect();
 
        vec.sort_unstable();
@@ -614,7 +668,7 @@ impl DataGenerator {
            .collect::<Vec<_>>();
 
        // pick from the 100 strings randomly
-        let mut input = (0..INPUT_SIZE)
+        let mut input = (0..self.input_size)
            .map(|_| {
                let idx = self.rng.random_range(0..strings.len());
                let s = Arc::clone(&strings[idx]);
@@ -629,7 +683,7 @@ impl DataGenerator {
 
     /// Create sorted values of high cardinality (~ no duplicates) utf8 values
     fn utf8_high_cardinality_values(&mut self) -> Vec<Option<String>> {
        // make random strings
-        let mut input = (0..INPUT_SIZE)
+        let mut input = (0..self.input_size)
            .map(|_| Some(self.random_string()))
            .collect::<Vec<_>>();
 
diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs
index e535a018161f1..54cd9a0bcd547 100644
--- a/datafusion/core/benches/sort_limit_query_sql.rs
+++ b/datafusion/core/benches/sort_limit_query_sql.rs
@@ -15,9 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
-#[macro_use]
-extern crate criterion;
-use criterion::Criterion;
+use criterion::{Criterion, criterion_group, criterion_main};
 use datafusion::datasource::file_format::csv::CsvFormat;
 use datafusion::datasource::listing::{
     ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
@@ -27,9 +25,6 @@ use datafusion::prelude::SessionConfig;
 use parking_lot::Mutex;
 use std::sync::Arc;
 
-extern crate arrow;
-extern crate datafusion;
-
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion::datasource::MemTable;
@@ -37,6 +32,7 @@ use datafusion::execution::context::SessionContext;
 
 use tokio::runtime::Runtime;
 
+#[expect(clippy::needless_pass_by_value)]
 fn query(ctx: Arc<Mutex<SessionContext>>, rt: &Runtime, sql: &str) {
     // execute the query
     let df = rt.block_on(ctx.lock().sql(sql)).unwrap();
@@ -97,8 +93,7 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
         ctx_holder.lock().push(Arc::new(Mutex::new(ctx)))
     });
 
-    let ctx = ctx_holder.lock().first().unwrap().clone();
-    ctx
+    ctx_holder.lock().first().unwrap().clone()
 }
 
 fn criterion_benchmark(c: &mut Criterion) {
diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs
index ecc3f908d4b15..afd384f7b170e 100644
--- a/datafusion/core/benches/spm.rs
+++ b/datafusion/core/benches/spm.rs
@@ -20,13 +20,13 @@ use std::sync::Arc;
 
 use arrow::array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray};
 use datafusion_execution::TaskContext;
-use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr::PhysicalSortExpr;
+use datafusion_physical_expr::expressions::col;
 use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
-use datafusion_physical_plan::{collect, ExecutionPlan};
+use datafusion_physical_plan::{ExecutionPlan, collect};
 
 use criterion::async_executor::FuturesExecutor;
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{Criterion, criterion_group, criterion_main};
 use datafusion_datasource::memory::MemorySourceConfig;
 
 fn generate_spm_for_round_robin_tie_breaker(
@@ -66,10 +66,9 @@ fn generate_spm_for_round_robin_tie_breaker(
         RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap()
     };
 
-    let rbs = (0..batch_count).map(|_| rb.clone()).collect::<Vec<_>>();
-    let partitions = vec![rbs.clone(); partition_count];
-    let schema = rb.schema();
+    let schema = rb.schema();
+    let rbs = std::iter::repeat_n(rb, batch_count).collect::<Vec<_>>();
+    let partitions = vec![rbs.clone(); partition_count];
     let sort = [
         PhysicalSortExpr {
             expr: col("b", &schema).unwrap(),
diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs
index 6266a7184cf51..fcc8da30fedd9 100644
--- a/datafusion/core/benches/sql_planner.rs
+++ b/datafusion/core/benches/sql_planner.rs
@@ -15,27 +15,26 @@
 // specific language governing permissions and limitations
 // under the License.
-extern crate arrow;
-#[macro_use]
-extern crate criterion;
-extern crate datafusion;
-
 mod data_utils;
 
-use crate::criterion::Criterion;
+use arrow::array::PrimitiveArray;
 use arrow::array::{ArrayRef, RecordBatch};
+use arrow::datatypes::ArrowNativeTypeOp;
+use arrow::datatypes::ArrowPrimitiveType;
 use arrow::datatypes::{DataType, Field, Fields, Schema};
 use criterion::Bencher;
+use criterion::{Criterion, criterion_group, criterion_main};
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionContext;
-use datafusion_common::{config::Dialect, ScalarValue};
+use datafusion_common::{ScalarValue, config::Dialect};
 use datafusion_expr::col;
+use rand_distr::num_traits::NumCast;
 use std::hint::black_box;
 use std::path::PathBuf;
 use std::sync::Arc;
+use test_utils::TableDef;
 use test_utils::tpcds::tpcds_schemas;
 use test_utils::tpch::tpch_schemas;
-use test_utils::TableDef;
 use tokio::runtime::Runtime;
 
 const BENCHMARKS_PATH_1: &str = "../../benchmarks/";
@@ -74,6 +73,21 @@ fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc<MemTable> {
 
+fn create_struct_table_provider() -> Arc<MemTable> {
+    let struct_fields = Fields::from(vec![
+        Field::new("value", DataType::Int32, true),
+        Field::new("label", DataType::Utf8, true),
+    ]);
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, true),
+        Field::new("props", DataType::Struct(struct_fields), true),
+    ]));
+    MemTable::try_new(schema, vec![vec![]])
+        .map(Arc::new)
+        .unwrap()
+}
+
 fn create_context() -> SessionContext {
     let ctx = SessionContext::new();
     ctx.register_table("t1", create_table_provider("a", 200))
@@ -84,11 +98,16 @@ fn create_context() -> SessionContext {
         .unwrap();
     ctx.register_table("t1000", create_table_provider("d", 1000))
         .unwrap();
+    ctx.register_table("struct_t1", create_struct_table_provider())
+        .unwrap();
+    ctx.register_table("struct_t2", create_struct_table_provider())
+        .unwrap();
     ctx
 }
 
 /// Register the table definitions as a MemTable with the context and return the
 /// context
+#[expect(clippy::needless_pass_by_value)]
 fn register_defs(ctx: SessionContext, defs: Vec<TableDef>) -> SessionContext {
     defs.iter().for_each(|TableDef { name, schema }| {
         ctx.register_table(
@@ -111,10 +130,27 @@ fn register_clickbench_hits_table(rt: &Runtime) -> SessionContext {
         format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")
     };
 
-    let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'");
+    let sql =
+        format!("CREATE EXTERNAL TABLE hits_raw STORED AS PARQUET LOCATION '{path}'");
 
+    // The ClickBench partitioned dataset was written by an ancient version of pyarrow
+    // that wrote strings with the wrong logical type. To read it correctly, we must
+    // automatically convert binary to string.
+    rt.block_on(ctx.sql("SET datafusion.execution.parquet.binary_as_string = true;"))
+        .unwrap();
     rt.block_on(ctx.sql(&sql)).unwrap();
 
+    // ClickBench stores EventDate as UInt16 (days since 1970-01-01). Create a view
+    // that exposes it as SQL DATE so that queries comparing it with date literals
+    // (e.g. "EventDate >= '2013-07-01'") work correctly during planning.
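+    // (Illustrative arithmetic: day 15_887 after 1970-01-01 is 2013-07-01, so
+    // with the cast in the view below the comparison really is a DATE
+    // comparison rather than an integer one.)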
+    rt.block_on(ctx.sql(
+        "CREATE VIEW hits AS \
+         SELECT * EXCEPT (\"EventDate\"), \
+         CAST(CAST(\"EventDate\" AS INTEGER) AS DATE) AS \"EventDate\" \
+         FROM hits_raw",
+    ))
+    .unwrap();
+
     let count =
         rt.block_on(async { ctx.table("hits").await.unwrap().count().await.unwrap() });
     assert!(count > 0);
@@ -155,18 +191,30 @@ fn benchmark_with_param_values_many_columns(
 /// 0,100...9900
 /// 0,200...19800
 /// 0,300...29700
-fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows: usize) {
-    // ("c0", [0, 0, ...])
-    // ("c1": [100, 200, ...])
-    // etc
-    let iter = (0..num_columns).map(|i| i as u64).map(|i| {
-        let array: ArrayRef = Arc::new(arrow::array::UInt64Array::from_iter_values(
-            (0..num_rows)
-                .map(|j| j as u64 * 100 + i)
-                .collect::<Vec<_>>(),
-        ));
+fn register_union_order_table_generic<T>(
+    ctx: &SessionContext,
+    num_columns: usize,
+    num_rows: usize,
+) where
+    T: ArrowPrimitiveType,
+    T::Native: ArrowNativeTypeOp + NumCast,
+{
+    let iter = (0..num_columns).map(|i| {
+        let array_data: Vec<T::Native> = (0..num_rows)
+            .map(|j| {
+                let value = (j as u64) * 100 + (i as u64);
+                <T::Native as NumCast>::from(value).unwrap_or_else(|| {
+                    panic!("Failed to cast numeric value to Native type")
+                })
+            })
+            .collect();
+
+        // Use PrimitiveArray which is generic over the ArrowPrimitiveType T
+        let array: ArrayRef = Arc::new(PrimitiveArray::<T>::from_iter_values(array_data));
+
         (format!("c{i}"), array)
     });
+
     let batch = RecordBatch::try_from_iter(iter).unwrap();
     let schema = batch.schema();
     let partitions = vec![vec![batch]];
@@ -183,7 +231,6 @@
     ctx.register_table("t", Arc::new(table)).unwrap();
 }
 
-
 /// return a query like
 /// ```sql
 /// select c1, 2 as c2, ... n as cn from t ORDER BY c1
@@ -226,8 +273,10 @@ fn criterion_benchmark(c: &mut Criterion) {
     if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists()
         && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists()
     {
-        panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \
-            'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark")
+        panic!(
+            "benchmarks/data/hits_partitioned/ could not be loaded. Please run \
+            'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark"
+        )
     }
 
     let ctx = create_context();
@@ -401,15 +450,61 @@ fn criterion_benchmark(c: &mut Criterion) {
         });
     });
 
+    let struct_agg_sort_query = "SELECT \
+        struct_t1.props['label'], \
+        SUM(struct_t1.props['value']), \
+        MAX(struct_t2.props['value']), \
+        COUNT(*) \
+        FROM struct_t1 \
+        JOIN struct_t2 ON struct_t1.id = struct_t2.id \
+        WHERE struct_t1.props['value'] > 50 \
+        GROUP BY struct_t1.props['label'] \
+        ORDER BY SUM(struct_t1.props['value']) DESC";
+
+    // -- Struct column benchmarks --
+    c.bench_function("logical_plan_struct_join_agg_sort", |b| {
+        b.iter(|| logical_plan(&ctx, &rt, struct_agg_sort_query))
+    });
+    c.bench_function("physical_plan_struct_join_agg_sort", |b| {
+        b.iter(|| physical_plan(&ctx, &rt, struct_agg_sort_query))
+    });
+
     // -- Sorted Queries --
     // 100, 200 && 300 is taking too long - https://github.com/apache/datafusion/issues/18366
+    // The logical plans for Int64 and UInt64 differ: the UInt64 plan's Unions are
+    // wrapped in Projections, so the EliminateNestedUnion optimizer rule is not
+    // applied, leading to significantly longer execution time.
+    // https://github.com/apache/datafusion/issues/17261
+
     for column_count in [10, 50 /* 100, 200, 300 */] {
-        register_union_order_table(&ctx, column_count, 1000);
+        register_union_order_table_generic::<arrow::datatypes::Int64Type>(
+            &ctx,
+            column_count,
+            1000,
+        );
 
         // this query has many expressions in its sort order so stresses
         // order equivalence validation
         c.bench_function(
-            &format!("physical_sorted_union_order_by_{column_count}"),
+            &format!("physical_sorted_union_order_by_{column_count}_int64"),
+            |b| {
+                // SELECT ... UNION ALL ...
+                let query = union_orderby_query(column_count);
+                b.iter(|| physical_plan(&ctx, &rt, &query))
+            },
+        );
+
+        let _ = ctx.deregister_table("t");
+    }
+
+    for column_count in [10, 50 /* 100, 200, 300 */] {
+        register_union_order_table_generic::<arrow::datatypes::UInt64Type>(
+            &ctx,
+            column_count,
+            1000,
+        );
+        c.bench_function(
+            &format!("physical_sorted_union_order_by_{column_count}_uint64"),
             |b| {
                 // SELECT ... UNION ALL ...
                 let query = union_orderby_query(column_count);
                 b.iter(|| physical_plan(&ctx, &rt, &query))
@@ -477,9 +572,6 @@ fn criterion_benchmark(c: &mut Criterion) {
     };
 
     let raw_tpcds_sql_queries = (1..100)
-        // skip query 75 until it is fixed
-        // https://github.com/apache/datafusion/issues/17801
-        .filter(|q| *q != 75)
         .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap())
        .collect::<Vec<_>>();
diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs
index aff7cb4d101d5..d4955313c79c3 100644
--- a/datafusion/core/benches/sql_planner_extended.rs
+++ b/datafusion/core/benches/sql_planner_extended.rs
@@ -18,7 +18,7 @@
 use arrow::array::{ArrayRef, RecordBatch};
 use arrow_schema::DataType;
 use arrow_schema::TimeUnit::Nanosecond;
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
 use datafusion::prelude::{DataFrame, SessionContext};
 use datafusion_catalog::MemTable;
 use datafusion_common::ScalarValue;
@@ -27,6 +27,7 @@ use datafusion_expr::{cast, col, lit, not, try_cast, when};
 use datafusion_functions::expr_fn::{
     btrim, length, regexp_like, regexp_replace, to_timestamp, upper,
 };
+use std::fmt::Write;
 use std::hint::black_box;
 use std::ops::Rem;
 use std::sync::Arc;
@@ -212,14 +213,127 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
     })
 }
 
-fn criterion_benchmark(c: &mut Criterion) {
+/// Build a CASE-heavy dataframe over a non-inner join to stress
+/// planner-time filter pushdown and nullability/type inference.
+fn build_case_heavy_left_join_df(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
+    register_string_table(ctx, 100, 1000);
+    let query = build_case_heavy_left_join_query(30, 1);
+    rt.block_on(async { ctx.sql(&query).await.unwrap() })
+}
+
+fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) -> String {
+    let mut query = String::from(
+        "SELECT l.c0, r.c0 AS rc0 FROM t l LEFT JOIN t r ON l.c0 = r.c0 WHERE ",
+    );
+
+    if predicate_count == 0 {
+        query.push_str("TRUE");
+        return query;
+    }
+
+    // Keep this deterministic so comparisons between profiles are stable.
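+    // For illustration (not emitted verbatim): predicate_count=2 and
+    // case_depth=1 produce a WHERE clause of the shape
+    //   CASE WHEN l.c1 IS NOT NULL THEN length(l.c0) ELSE length(r.c2) END > 2
+    //   AND CASE WHEN l.c2 IS NOT NULL THEN length(l.c1) ELSE length(r.c3) END > 2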
+ for i in 0..predicate_count { + if i > 0 { + query.push_str(" AND "); + } + + let mut expr = format!("length(l.c{})", i % 20); + for depth in 0..case_depth { + let left_col = (i + depth + 1) % 20; + let right_col = (i + depth + 2) % 20; + expr = format!( + "CASE WHEN l.c{left_col} IS NOT NULL THEN {expr} ELSE length(r.c{right_col}) END" + ); + } + + let _ = write!(&mut query, "{expr} > 2"); + } + + query +} + +fn build_case_heavy_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + case_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + let ctx = SessionContext::new(); + register_string_table(&ctx, 100, 1000); + if !push_down_filter_enabled { + let removed = ctx.remove_optimizer_rule("push_down_filter"); + assert!( + removed, + "push_down_filter rule should be present in the default optimizer" + ); + } + + let query = build_case_heavy_left_join_query(predicate_count, case_depth); + rt.block_on(async { ctx.sql(&query).await.unwrap() }) +} + +fn build_non_case_left_join_query( + predicate_count: usize, + nesting_depth: usize, +) -> String { + let mut query = String::from( + "SELECT l.c0, r.c0 AS rc0 FROM t l LEFT JOIN t r ON l.c0 = r.c0 WHERE ", + ); + + if predicate_count == 0 { + query.push_str("TRUE"); + return query; + } + + // Keep this deterministic so comparisons between profiles are stable. + for i in 0..predicate_count { + if i > 0 { + query.push_str(" AND "); + } + + let left_col = i % 20; + let mut expr = format!("l.c{left_col}"); + for depth in 0..nesting_depth { + let right_col = (i + depth + 1) % 20; + expr = format!("coalesce({expr}, r.c{right_col})"); + } + + let _ = write!(&mut query, "length({expr}) > 2"); + } + + query +} + +fn build_non_case_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + nesting_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { let ctx = SessionContext::new(); + register_string_table(&ctx, 100, 1000); + if !push_down_filter_enabled { + let removed = ctx.remove_optimizer_rule("push_down_filter"); + assert!( + removed, + "push_down_filter rule should be present in the default optimizer" + ); + } + + let query = build_non_case_left_join_query(predicate_count, nesting_depth); + rt.block_on(async { ctx.sql(&query).await.unwrap() }) +} + +fn criterion_benchmark(c: &mut Criterion) { + let baseline_ctx = SessionContext::new(); + let case_heavy_ctx = SessionContext::new(); let rt = Runtime::new().unwrap(); // validate logical plan optimize performance // https://github.com/apache/datafusion/issues/17261 - let df = build_test_data_frame(&ctx, &rt); + let df = build_test_data_frame(&baseline_ctx, &rt); + let case_heavy_left_join_df = build_case_heavy_left_join_df(&case_heavy_ctx, &rt); c.bench_function("logical_plan_optimize", |b| { b.iter(|| { @@ -227,6 +341,125 @@ fn criterion_benchmark(c: &mut Criterion) { black_box(rt.block_on(async { df_clone.into_optimized_plan().unwrap() })); }) }); + + c.bench_function("logical_plan_optimize_hotspot_case_heavy_left_join", |b| { + b.iter(|| { + let df_clone = case_heavy_left_join_df.clone(); + black_box(rt.block_on(async { df_clone.into_optimized_plan().unwrap() })); + }) + }); + + let predicate_sweep = [10, 20, 30, 40, 60]; + let case_depth_sweep = [1, 2, 3]; + + let mut hotspot_group = + c.benchmark_group("push_down_filter_hotspot_case_heavy_left_join_ab"); + for case_depth in case_depth_sweep { + for predicate_count in predicate_sweep { + let with_push_down_filter = + build_case_heavy_left_join_df_with_push_down_filter( + &rt, + 
predicate_count, + case_depth, + true, + ); + let without_push_down_filter = + build_case_heavy_left_join_df_with_push_down_filter( + &rt, + predicate_count, + case_depth, + false, + ); + + let input_label = + format!("predicates={predicate_count},case_depth={case_depth}"); + // A/B interpretation: + // - with_push_down_filter: default optimizer path (rule enabled) + // - without_push_down_filter: control path with the rule removed + // Compare both IDs at the same sweep point to isolate rule impact. + hotspot_group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + hotspot_group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + } + } + hotspot_group.finish(); + + let mut control_group = + c.benchmark_group("push_down_filter_control_non_case_left_join_ab"); + for nesting_depth in case_depth_sweep { + for predicate_count in predicate_sweep { + let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( + &rt, + predicate_count, + nesting_depth, + true, + ); + let without_push_down_filter = + build_non_case_left_join_df_with_push_down_filter( + &rt, + predicate_count, + nesting_depth, + false, + ); + + let input_label = + format!("predicates={predicate_count},nesting_depth={nesting_depth}"); + control_group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + control_group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + } + } + control_group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58797dfed6b67..fc8caf31acd11 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -20,7 +20,7 @@ use std::{fmt::Write, sync::Arc, time::Duration}; use arrow::array::{Int64Builder, RecordBatch, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use bytes::Bytes; -use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion::{ datasource::{ file_format::parquet::ParquetFormat, @@ -31,13 +31,13 @@ use datafusion::{ use datafusion_execution::runtime_env::RuntimeEnv; use itertools::Itertools; use object_store::{ + ObjectStore, ObjectStoreExt, memory::InMemory, path::Path, throttle::{ThrottleConfig, ThrottledStore}, - ObjectStore, }; use parquet::arrow::ArrowWriter; -use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand::{Rng, SeedableRng, rngs::StdRng}; use tokio::runtime::Runtime; use url::Url; diff --git a/datafusion/core/benches/struct_query_sql.rs b/datafusion/core/benches/struct_query_sql.rs index 5c7b427310827..96434fc379ea6 100644 --- 
a/datafusion/core/benches/struct_query_sql.rs
+++ b/datafusion/core/benches/struct_query_sql.rs
@@ -20,7 +20,7 @@ use arrow::{
     datatypes::{DataType, Field, Schema},
     record_batch::RecordBatch,
 };
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{Criterion, criterion_group, criterion_main};
 use datafusion::prelude::SessionContext;
 use datafusion::{datasource::MemTable, error::Result};
 use futures::executor::block_on;
diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs
index 9a5fb7163be5c..c78b1ea494407 100644
--- a/datafusion/core/benches/topk_aggregate.rs
+++ b/datafusion/core/benches/topk_aggregate.rs
@@ -17,26 +17,70 @@
 
 mod data_utils;
 
+use arrow::array::Int64Builder;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use criterion::{criterion_group, criterion_main, Criterion};
+use criterion::{Criterion, criterion_group, criterion_main};
 use data_utils::make_data;
-use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
+use datafusion::physical_plan::{collect, displayable};
 use datafusion::prelude::SessionContext;
 use datafusion::{datasource::MemTable, error::Result};
 use datafusion_execution::config::SessionConfig;
-use datafusion_execution::TaskContext;
+use rand::SeedableRng;
+use rand::seq::SliceRandom;
 use std::hint::black_box;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
 
+const LIMIT: usize = 10;
+
+/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids
+/// This ensures consistent results across benchmark runs
+fn make_distinct_data(
+    partition_cnt: i32,
+    sample_cnt: i32,
+) -> Result<(Arc<Schema>, Vec<Vec<RecordBatch>>)> {
+    let mut rng = rand::rngs::SmallRng::from_seed([42; 32]);
+    let total_samples = partition_cnt as usize * sample_cnt as usize;
+    let mut ids = Vec::new();
+    for i in 0..total_samples {
+        ids.push(i as i64);
+    }
+    ids.shuffle(&mut rng);
+
+    let mut global_idx = 0;
+    let schema = test_distinct_schema();
+    let mut partitions = vec![];
+    for _ in 0..partition_cnt {
+        let mut id_builder = Int64Builder::new();
+
+        for _ in 0..sample_cnt {
+            let id = ids[global_idx];
+            id_builder.append_value(id);
+            global_idx += 1;
+        }
+
+        let id_col = Arc::new(id_builder.finish());
+        let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?;
+        partitions.push(vec![batch]);
+    }
+
+    Ok((schema, partitions))
+}
+
+/// Returns a Schema for distinct benchmarks with i64 trace_id
+fn test_distinct_schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
+}
+
 async fn create_context(
-    limit: usize,
     partition_cnt: i32,
     sample_cnt: i32,
     asc: bool,
     use_topk: bool,
     use_view: bool,
-) -> Result<(Arc<dyn ExecutionPlan>, Arc<TaskContext>)> {
+) -> Result<SessionContext> {
     let (schema, parts) = make_data(partition_cnt, sample_cnt, asc, use_view).unwrap();
     let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap());
 
@@ -46,165 +90,408 @@ async fn create_context(
     opts.optimizer.enable_topk_aggregation = use_topk;
     let ctx = SessionContext::new_with_config(cfg);
     let _ = ctx.register_table("traces", mem_table)?;
-    let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};");
+
+    Ok(ctx)
+}
+
+async fn create_context_distinct(
+    partition_cnt: i32,
+    sample_cnt: i32,
+    use_topk: bool,
+) -> Result<SessionContext> {
+    // Use deterministic data generation for DISTINCT queries to ensure consistent results
+    let (schema, parts) = 
make_distinct_data(partition_cnt, sample_cnt).unwrap();
+    let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap());
+
+    // Create the DataFrame
+    let mut cfg = SessionConfig::new();
+    let opts = cfg.options_mut();
+    opts.optimizer.enable_topk_aggregation = use_topk;
+    let ctx = SessionContext::new_with_config(cfg);
+    let _ = ctx.register_table("traces", mem_table)?;
+
+    Ok(ctx)
+}
+
+fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) {
+    black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap();
+}
+
+fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) {
+    black_box(rt.block_on(async { aggregate_string(ctx, limit, use_topk).await }))
+        .unwrap();
+}
+
+fn run_distinct(
+    rt: &Runtime,
+    ctx: SessionContext,
+    limit: usize,
+    use_topk: bool,
+    asc: bool,
+) {
+    black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await }))
+        .unwrap();
+}
+
+async fn aggregate(
+    ctx: SessionContext,
+    limit: usize,
+    use_topk: bool,
+    asc: bool,
+) -> Result<()> {
+    let sql = format!(
+        "select max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};"
+    );
     let df = ctx.sql(sql.as_str()).await?;
-    let physical_plan = df.create_physical_plan().await?;
-    let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string();
+    let plan = df.create_physical_plan().await?;
+    let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string();
     assert_eq!(
         actual_phys_plan.contains(&format!("lim=[{limit}]")),
         use_topk
     );
 
-    Ok((physical_plan, ctx.task_ctx()))
+    let batches = collect(plan, ctx.task_ctx()).await?;
+    assert_eq!(batches.len(), 1);
+    let batch = batches.first().unwrap();
+    assert_eq!(batch.num_rows(), LIMIT);
+
+    let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase();
+    let expected_asc = r#"
++--------------------------+
+| max(traces.timestamp_ms) |
++--------------------------+
+| 16909009999999           |
+| 16909009999998           |
+| 16909009999997           |
+| 16909009999996           |
+| 16909009999995           |
+| 16909009999994           |
+| 16909009999993           |
+| 16909009999992           |
+| 16909009999991           |
+| 16909009999990           |
++--------------------------+
+    "#
+    .trim();
+    if asc {
+        assert_eq!(actual.trim(), expected_asc);
+    }
+
+    Ok(())
 }
 
-fn run(rt: &Runtime, plan: Arc<dyn ExecutionPlan>, ctx: Arc<TaskContext>, asc: bool) {
-    black_box(rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await }))
-        .unwrap();
+/// Benchmark for string aggregate functions with topk optimization.
+/// This tests grouping by a numeric column (timestamp_ms) and aggregating
+/// a string column (trace_id) with Utf8 or Utf8View data types.
+async fn aggregate_string(
+    ctx: SessionContext,
+    limit: usize,
+    use_topk: bool,
+) -> Result<Vec<RecordBatch>> {
+    let sql = format!(
+        "select max(trace_id) from traces group by timestamp_ms order by max(trace_id) desc limit {limit};"
+    );
+    let df = ctx.sql(sql.as_str()).await?;
+    let plan = df.create_physical_plan().await?;
+    let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string();
+    assert_eq!(
+        actual_phys_plan.contains(&format!("lim=[{limit}]")),
+        use_topk
+    );
+
+    let batches = collect(plan, ctx.task_ctx()).await?;
+    assert_eq!(batches.len(), 1);
+    let batch = batches.first().unwrap();
+    assert_eq!(batch.num_rows(), LIMIT);
+
+    Ok(batches)
 }
 
-async fn aggregate(
-    plan: Arc<dyn ExecutionPlan>,
-    ctx: Arc<TaskContext>,
+async fn aggregate_distinct(
+    ctx: SessionContext,
+    limit: usize,
+    use_topk: bool,
     asc: bool,
 ) -> Result<()> {
-    let batches = collect(plan, ctx).await?;
+    let order_direction = if asc { "asc" } else { "desc" };
+    let sql = format!(
+        "select id from traces group by id order by id {order_direction} limit {limit};"
+    );
+    let df = ctx.sql(sql.as_str()).await?;
+    let plan = df.create_physical_plan().await?;
+    let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string();
+    assert_eq!(
+        actual_phys_plan.contains(&format!("lim=[{limit}]")),
+        use_topk
+    );
+    let batches = collect(plan, ctx.task_ctx()).await?;
     assert_eq!(batches.len(), 1);
     let batch = batches.first().unwrap();
-    assert_eq!(batch.num_rows(), 10);
+    assert_eq!(batch.num_rows(), LIMIT);
 
     let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase();
+
     let expected_asc = r#"
-+----------------------------------+--------------------------+
-| trace_id                         | max(traces.timestamp_ms) |
-+----------------------------------+--------------------------+
-| 5868861a23ed31355efc5200eb80fe74 | 16909009999999           |
-| 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998           |
-| 02801bbe533190a9f8713d75222f445d | 16909009999997           |
-| 9e31b3b5a620de32b68fefa5aeea57f1 | 16909009999996           |
-| 2d88a860e9bd1cfaa632d8e7caeaa934 | 16909009999995           |
-| a47edcef8364ab6f191dd9103e51c171 | 16909009999994           |
-| 36a3fa2ccfbf8e00337f0b1254384db6 | 16909009999993           |
-| 0756be84f57369012e10de18b57d8a2f | 16909009999992           |
-| d4d6bf9845fa5897710e3a8db81d5907 | 16909009999991           |
-| 3c2cc1abe728a66b61e14880b53482a0 | 16909009999990           |
-+----------------------------------+--------------------------+
-    "#
++----+
+| id |
++----+
+| 0  |
+| 1  |
+| 2  |
+| 3  |
+| 4  |
+| 5  |
+| 6  |
+| 7  |
+| 8  |
+| 9  |
++----+
+"#
     .trim();
+
+    let expected_desc = r#"
++---------+
+| id      |
++---------+
+| 9999999 |
+| 9999998 |
+| 9999997 |
+| 9999996 |
+| 9999995 |
+| 9999994 |
+| 9999993 |
+| 9999992 |
+| 9999991 |
+| 9999990 |
++---------+
+"#
+    .trim();
+
+    // Verify exact results match expected values
     if asc {
-        assert_eq!(actual.trim(), expected_asc);
+        assert_eq!(
+            actual.trim(),
+            expected_asc,
+            "Ascending DISTINCT results do not match expected values"
+        );
+    } else {
+        assert_eq!(
+            actual.trim(),
+            expected_desc,
+            "Descending DISTINCT results do not match expected values"
+        );
     }
 
     Ok(())
 }
 
+struct BenchCase<'a> {
+    name_tpl: &'a str,
+    asc: bool,
+    use_topk: bool,
+    use_view: bool,
+}
+
+struct StringCase {
+    asc: bool,
+    use_topk: bool,
+    use_view: bool,
+}
+
+fn assert_utf8_utf8view_match(
+    rt: &Runtime,
+    partitions: i32,
+    samples: i32,
+    limit: usize,
+    asc: bool,
+    use_topk: bool,
+) {
+    let ctx_utf8 = rt
+        .block_on(create_context(partitions, samples, asc, use_topk, false))
+        .unwrap();
+    let ctx_view = rt
+        .block_on(create_context(partitions, samples, asc, 
use_topk, true)) + .unwrap(); + let batches_utf8 = rt + .block_on(aggregate_string(ctx_utf8, limit, use_topk)) + .unwrap(); + let batches_view = rt + .block_on(aggregate_string(ctx_view, limit, use_topk)) + .unwrap(); + let result_utf8 = pretty_format_batches(&batches_utf8).unwrap().to_string(); + let result_view = pretty_format_batches(&batches_view).unwrap().to_string(); + assert_eq!( + result_utf8, result_view, + "Utf8 vs Utf8View mismatch for asc={asc}, use_topk={use_topk}" + ); +} + +fn assert_string_results_match( + rt: &Runtime, + partitions: i32, + samples: i32, + limit: usize, +) { + for asc in [false, true] { + for use_topk in [false, true] { + assert_utf8_utf8view_match(rt, partitions, samples, limit, asc, use_topk); + } + } +} + fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); - let limit = 10; + let limit = LIMIT; let partitions = 10; let samples = 1_000_000; + let total_rows = partitions * samples; - c.bench_function( - format!("aggregate {} time-series rows", partitions * samples).as_str(), - |b| { - b.iter(|| { - let real = rt.block_on(async { - create_context(limit, partitions, samples, false, false, false) - .await - .unwrap() - }); - run(&rt, real.0.clone(), real.1.clone(), false) - }) + // Numeric aggregate benchmarks + let numeric_cases = &[ + BenchCase { + name_tpl: "aggregate {rows} time-series rows", + asc: false, + use_topk: false, + use_view: false, }, - ); - - c.bench_function( - format!("aggregate {} worst-case rows", partitions * samples).as_str(), - |b| { - b.iter(|| { - let asc = rt.block_on(async { - create_context(limit, partitions, samples, true, false, false) - .await - .unwrap() - }); - run(&rt, asc.0.clone(), asc.1.clone(), true) - }) + BenchCase { + name_tpl: "aggregate {rows} worst-case rows", + asc: true, + use_topk: false, + use_view: false, }, - ); - - c.bench_function( - format!( - "top k={limit} aggregate {} time-series rows", - partitions * samples - ) - .as_str(), - |b| { - b.iter(|| { - let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, true, false) - .await - .unwrap() - }); - run(&rt, topk_real.0.clone(), topk_real.1.clone(), false) - }) + BenchCase { + name_tpl: "top k={limit} aggregate {rows} time-series rows", + asc: false, + use_topk: true, + use_view: false, }, - ); - - c.bench_function( - format!( - "top k={limit} aggregate {} worst-case rows", - partitions * samples - ) - .as_str(), - |b| { - b.iter(|| { - let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true, false) - .await - .unwrap() - }); - run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) - }) + BenchCase { + name_tpl: "top k={limit} aggregate {rows} worst-case rows", + asc: true, + use_topk: true, + use_view: false, }, - ); - - // Utf8View schema,time-series rows - c.bench_function( - format!( - "top k={limit} aggregate {} time-series rows [Utf8View]", - partitions * samples - ) - .as_str(), - |b| { - b.iter(|| { - let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, true, true) - .await - .unwrap() - }); - run(&rt, topk_real.0.clone(), topk_real.1.clone(), false) - }) + BenchCase { + name_tpl: "top k={limit} aggregate {rows} time-series rows [Utf8View]", + asc: false, + use_topk: true, + use_view: true, }, - ); + BenchCase { + name_tpl: "top k={limit} aggregate {rows} worst-case rows [Utf8View]", + asc: true, + use_topk: true, + use_view: true, + }, + ]; + for case in numeric_cases { + let name = case + .name_tpl + .replace("{rows}", 
&total_rows.to_string()) + .replace("{limit}", &limit.to_string()); + let ctx = rt + .block_on(create_context( + partitions, + samples, + case.asc, + case.use_topk, + case.use_view, + )) + .unwrap(); + c.bench_function(&name, |b| { + b.iter(|| run(&rt, ctx.clone(), limit, case.use_topk, case.asc)) + }); + } - // Utf8View schema,worst-case rows - c.bench_function( - format!( - "top k={limit} aggregate {} worst-case rows [Utf8View]", - partitions * samples - ) - .as_str(), - |b| { - b.iter(|| { - let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true, true) - .await - .unwrap() - }); - run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) - }) + assert_string_results_match(&rt, partitions, samples, limit); + + let string_cases = &[ + StringCase { + asc: false, + use_topk: false, + use_view: false, }, - ); + StringCase { + asc: true, + use_topk: false, + use_view: false, + }, + StringCase { + asc: false, + use_topk: false, + use_view: true, + }, + StringCase { + asc: true, + use_topk: false, + use_view: true, + }, + StringCase { + asc: false, + use_topk: true, + use_view: false, + }, + StringCase { + asc: true, + use_topk: true, + use_view: false, + }, + StringCase { + asc: false, + use_topk: true, + use_view: true, + }, + StringCase { + asc: true, + use_topk: true, + use_view: true, + }, + ]; + for case in string_cases { + let scenario = if case.asc { + "worst-case" + } else { + "time-series" + }; + let type_label = if case.use_view { "Utf8View" } else { "Utf8" }; + let name = if case.use_topk { + format!( + "top k={limit} string aggregate {total_rows} {scenario} rows [{type_label}]" + ) + } else { + format!("string aggregate {total_rows} {scenario} rows [{type_label}]") + }; + let ctx = rt + .block_on(create_context( + partitions, + samples, + case.asc, + case.use_topk, + case.use_view, + )) + .unwrap(); + c.bench_function(&name, |b| { + b.iter(|| run_string(&rt, ctx.clone(), limit, case.use_topk)) + }); + } + + // DISTINCT benchmarks + for use_topk in [false, true] { + let ctx = rt.block_on(async { + create_context_distinct(partitions, samples, use_topk) + .await + .unwrap() + }); + let topk_label = if use_topk { "TopK" } else { "no TopK" }; + for asc in [false, true] { + let dir = if asc { "asc" } else { "desc" }; + let name = format!("distinct {total_rows} rows {dir} [{topk_label}]"); + c.bench_function(&name, |b| { + b.iter(|| run_distinct(&rt, ctx.clone(), limit, use_topk, asc)) + }); + } + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/topk_repartition.rs b/datafusion/core/benches/topk_repartition.rs new file mode 100644 index 0000000000000..e1f14e4aaa633 --- /dev/null +++ b/datafusion/core/benches/topk_repartition.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Benchmark for the TopKRepartition optimizer rule.
+//!
+//! Measures the benefit of pushing TopK (Sort with fetch) below hash
+//! repartition when running partitioned window functions with LIMIT.
+
+mod data_utils;
+
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
+use data_utils::create_table_provider;
+use datafusion::prelude::{SessionConfig, SessionContext};
+use parking_lot::Mutex;
+use std::hint::black_box;
+use std::sync::Arc;
+use tokio::runtime::Runtime;
+
+#[expect(clippy::needless_pass_by_value)]
+fn query(ctx: Arc<Mutex<SessionContext>>, rt: &Runtime, sql: &str) {
+    let df = rt.block_on(ctx.lock().sql(sql)).unwrap();
+    black_box(rt.block_on(df.collect()).unwrap());
+}
+
+fn create_context(
+    partitions_len: usize,
+    target_partitions: usize,
+    enable_topk_repartition: bool,
+) -> Arc<Mutex<SessionContext>> {
+    let array_len = 1024 * 1024;
+    let batch_size = 8 * 1024;
+    let mut config = SessionConfig::new().with_target_partitions(target_partitions);
+    config.options_mut().optimizer.enable_topk_repartition = enable_topk_repartition;
+    let ctx = SessionContext::new_with_config(config);
+    let rt = Runtime::new().unwrap();
+    rt.block_on(async {
+        let provider =
+            create_table_provider(partitions_len, array_len, batch_size).unwrap();
+        ctx.register_table("t", provider).unwrap();
+    });
+    Arc::new(Mutex::new(ctx))
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let rt = Runtime::new().unwrap();
+
+    let limits = [10, 1_000, 10_000, 100_000];
+    let scans = 16;
+    let target_partitions = 4;
+
+    let group = format!("topk_repartition_{scans}_to_{target_partitions}");
+    let mut group = c.benchmark_group(group);
+    for limit in limits {
+        let sql = format!(
+            "SELECT \
+             SUM(f64) OVER (PARTITION BY u64_narrow ORDER BY u64_wide ROWS UNBOUNDED PRECEDING) \
+             FROM t \
+             ORDER BY u64_narrow, u64_wide \
+             LIMIT {limit}"
+        );
+
+        let ctx_disabled = create_context(scans, target_partitions, false);
+        group.bench_function(BenchmarkId::new("disabled", limit), |b| {
+            b.iter(|| query(ctx_disabled.clone(), &rt, &sql))
+        });
+
+        let ctx_enabled = create_context(scans, target_partitions, true);
+        group.bench_function(BenchmarkId::new("enabled", limit), |b| {
+            b.iter(|| query(ctx_enabled.clone(), &rt, &sql))
+        });
+    }
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs
index 6d83959f7eb3c..1657cae913fef 100644
--- a/datafusion/core/benches/window_query_sql.rs
+++ b/datafusion/core/benches/window_query_sql.rs
@@ -15,14 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
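The knob this new benchmark measures is an ordinary optimizer option; a minimal sketch of flipping it, using the same configuration path as the benchmark's `create_context`:

```rust
// Sketch only: toggles the rule the topk_repartition benchmark compares.
use datafusion::prelude::{SessionConfig, SessionContext};

fn context_with_topk_repartition(enabled: bool) -> SessionContext {
    let mut config = SessionConfig::new().with_target_partitions(4);
    // When enabled, the TopK (Sort with fetch) can be pushed below the hash
    // repartition, shrinking the data each partition must sort.
    config.options_mut().optimizer.enable_topk_repartition = enabled;
    SessionContext::new_with_config(config)
}
```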
-#[macro_use]
-extern crate criterion;
-extern crate arrow;
-extern crate datafusion;
-
 mod data_utils;
 
-use crate::criterion::Criterion;
+use criterion::{Criterion, criterion_group, criterion_main};
 use data_utils::create_table_provider;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
@@ -31,6 +26,7 @@ use std::hint::black_box;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
 
+#[expect(clippy::needless_pass_by_value)]
 fn query(ctx: Arc<Mutex<SessionContext>>, rt: &Runtime, sql: &str) {
     let df = rt.block_on(ctx.lock().sql(sql)).unwrap();
     black_box(rt.block_on(df.collect()).unwrap());
diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs
index 63387c023b11a..c34865a32d532 100644
--- a/datafusion/core/src/bin/print_functions_docs.rs
+++ b/datafusion/core/src/bin/print_functions_docs.rs
@@ -16,14 +16,15 @@
 // under the License.
 
 use datafusion::execution::SessionStateDefaults;
-use datafusion_common::{not_impl_err, HashSet, Result};
+use datafusion_common::{HashSet, Result, not_impl_err};
 use datafusion_expr::{
-    aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF,
-    DocSection, Documentation, ScalarUDF, WindowUDF,
+    AggregateUDF, DocSection, Documentation, HigherOrderUDF, ScalarUDF, WindowUDF,
+    aggregate_doc_sections, scalar_doc_sections, window_doc_sections,
 };
 use itertools::Itertools;
 use std::env::args;
 use std::fmt::Write as _;
+use std::sync::Arc;
 
 /// Print documentation for all functions of a given type to stdout
 ///
@@ -71,6 +72,10 @@ fn print_scalar_docs() -> Result<String> {
         providers.push(Box::new(f.as_ref().clone()));
     }
 
+    for f in SessionStateDefaults::default_higher_order_functions() {
+        providers.push(Box::new(f));
+    }
+
     print_docs(providers, scalar_doc_sections::doc_sections())
 }
 
@@ -84,30 +89,7 @@
     print_docs(providers, window_doc_sections::doc_sections())
 }
 
-// Temporary method useful to semi automate
-// the migration of UDF documentation generation from code based
-// to attribute based
-// To be removed
-#[allow(dead_code)]
-fn save_doc_code_text(documentation: &Documentation, name: &str) {
-    let attr_text = documentation.to_doc_attribute();
-
-    let file_path = format!("{name}.txt");
-    if std::path::Path::new(&file_path).exists() {
-        std::fs::remove_file(&file_path).unwrap();
-    }
-
-    // Open the file in append mode, create it if it doesn't exist
-    let mut file = std::fs::OpenOptions::new()
-        .append(true) // Open in append mode
-        .create(true) // Create the file if it doesn't exist
-        .open(file_path)
-        .unwrap();
-
-    use std::io::Write;
-    file.write_all(attr_text.as_bytes()).unwrap();
-}
-
+#[expect(clippy::needless_pass_by_value)]
 fn print_docs(
     providers: Vec<Box<dyn DocProvider>>,
     doc_sections: Vec<DocSection>,
@@ -254,7 +236,9 @@
         for f in &providers_with_no_docs {
             eprintln!("  - {f}");
         }
-        not_impl_err!("Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}")
+        not_impl_err!(
+            "Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}"
+        )
     } else {
         Ok(docs)
     }
@@ -303,8 +287,19 @@ impl DocProvider for WindowUDF {
     }
 }
 
-#[allow(clippy::borrowed_box)]
-#[allow(clippy::ptr_arg)]
+impl DocProvider for Arc<HigherOrderUDF> {
+    fn get_name(&self) -> String {
+        self.name().to_string()
+    }
+    fn get_aliases(&self) -> Vec<String> {
+        self.aliases().iter().map(|a| a.to_string()).collect()
+    }
+    fn get_documentation(&self) -> Option<&Documentation> {
+        self.documentation()
+    }
+}
+
+#[expect(clippy::borrowed_box)]
 fn get_names_and_aliases(functions: &Vec<&Box<dyn DocProvider>>) -> Vec<String> {
     functions
         .iter()
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 98804e424b407..0f38988c69405 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -26,22 +26,21 @@ use crate::datasource::file_format::csv::CsvFormatFactory;
 use crate::datasource::file_format::format_as_file_type;
 use crate::datasource::file_format::json::JsonFormatFactory;
 use crate::datasource::{
-    provider_as_source, DefaultTableSource, MemTable, TableProvider,
+    DefaultTableSource, MemTable, TableProvider, provider_as_source,
 };
 use crate::error::Result;
-use crate::execution::context::{SessionState, TaskContext};
 use crate::execution::FunctionRegistry;
+use crate::execution::context::{SessionState, TaskContext};
 use crate::logical_expr::utils::find_window_exprs;
 use crate::logical_expr::{
-    col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder,
-    LogicalPlanBuilderOptions, Partitioning, TableType,
+    Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
+    Partitioning, TableType, col, ident,
 };
 use crate::physical_plan::{
-    collect, collect_partitioned, execute_stream, execute_stream_partitioned,
-    ExecutionPlan, SendableRecordBatchStream,
+    ExecutionPlan, SendableRecordBatchStream, collect, collect_partitioned,
+    execute_stream, execute_stream_partitioned,
 };
 use crate::prelude::SessionContext;
-use std::any::Any;
 use std::borrow::Cow;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
@@ -49,20 +48,20 @@
 use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
 use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow_schema::FieldRef;
 use datafusion_common::config::{CsvOptions, JsonOptions};
 use datafusion_common::{
-    exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
     Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError,
-    TableReference, UnnestOptions,
+    TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err,
+    plan_datafusion_err, plan_err, unqualified_field_not_found,
 };
 use datafusion_expr::select_expr::SelectExpr;
 use datafusion_expr::{
-    case,
+    ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case,
     dml::InsertOp,
     expr::{Alias, ScalarFunction},
     is_null, lit,
     utils::COUNT_STAR_EXPANSION,
-    ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE,
 };
 use datafusion_functions::core::coalesce;
 use datafusion_functions_aggregate::expr_fn::{
@@ -71,6 +70,7 @@
 
 use async_trait::async_trait;
 use datafusion_catalog::Session;
+use datafusion_expr::extension_types::DFArrayFormatterFactory;
 
 /// Contains options that control how data is
 /// written out from a DataFrame
@@ -78,9 +78,11 @@ pub struct DataFrameWriteOptions {
     /// Controls how new data should be written to the table, determining whether
     /// to append, overwrite, or replace existing data.
     insert_op: InsertOp,
-    /// Controls if all partitions should be coalesced into a single output file
-    /// Generally will have slower performance when set to true.
-    single_file_output: bool,
+    /// Controls if all partitions should be coalesced into a single output file.
+    /// - `None`: Use automatic mode (extension-based heuristic)
+    /// - `Some(true)`: Force single file output at exact path
+    /// - `Some(false)`: Force directory output with generated filenames
+    single_file_output: Option<bool>,
     /// Sets which columns should be used for hive-style partitioned writes by name.
     /// Can be set to empty vec![] for non-partitioned writes.
     partition_by: Vec<String>,
@@ -94,7 +96,7 @@ impl DataFrameWriteOptions {
     pub fn new() -> Self {
         DataFrameWriteOptions {
             insert_op: InsertOp::Append,
-            single_file_output: false,
+            single_file_output: None,
             partition_by: vec![],
             sort_by: vec![],
         }
@@ -107,8 +109,14 @@
     }
 
     /// Set the single_file_output value to true or false
+    ///
+    /// - `true`: Force single file output at the exact path specified
+    /// - `false`: Force directory output with generated filenames
+    ///
+    /// When not called, automatic mode is used (extension-based heuristic).
+    /// When set to true, an output file will always be created even if the DataFrame is empty.
     pub fn with_single_file_output(mut self, single_file_output: bool) -> Self {
-        self.single_file_output = single_file_output;
+        self.single_file_output = Some(single_file_output);
         self
     }
 
@@ -123,6 +131,15 @@
         self.sort_by = sort_by;
         self
     }
+
+    /// Build the options HashMap to pass to CopyTo for sink configuration.
+    fn build_sink_options(&self) -> HashMap<String, String> {
+        let mut options = HashMap::new();
+        if let Some(single_file) = self.single_file_output {
+            options.insert("single_file_output".to_string(), single_file.to_string());
+        }
+        options
+    }
 }
 
 impl Default for DataFrameWriteOptions {
@@ -277,8 +294,11 @@
         self.session_state.create_logical_expr(sql, df_schema)
     }
 
-    /// Consume the DataFrame and produce a physical plan
-    pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
+    /// Create a physical plan from this DataFrame.
+    ///
+    /// The `DataFrame` remains accessible after this call, so you can inspect
+    /// the plan and still call [`DataFrame::collect`] or other execution methods.
+    pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
         self.session_state.create_physical_plan(&self.plan).await
     }
 
@@ -310,11 +330,20 @@
     pub fn select_columns(self, columns: &[&str]) -> Result<DataFrame> {
         let fields = columns
             .iter()
-            .flat_map(|name| {
-                self.plan
+            .map(|name| {
+                let fields = self
+                    .plan
                     .schema()
-                    .qualified_fields_with_unqualified_name(name)
+                    .qualified_fields_with_unqualified_name(name);
+                if fields.is_empty() {
+                    Err(unqualified_field_not_found(name, self.plan.schema()))
+                } else {
+                    Ok(fields)
+                }
             })
+            .collect::<Result<Vec<_>, _>>()?
+            .into_iter()
+            .flatten()
             .collect::<Vec<_>>();
         let expr: Vec<Expr> = fields
             .into_iter()
@@ -436,15 +465,31 @@
     /// # Ok(())
     /// # }
     /// ```
-    pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
+    pub fn drop_columns<T>(self, columns: &[T]) -> Result<DataFrame>
+    where
+        T: Into<Column> + Clone,
+    {
         let fields_to_drop = columns
             .iter()
-            .flat_map(|name| {
-                self.plan
-                    .schema()
-                    .qualified_fields_with_unqualified_name(name)
+            .flat_map(|col| {
+                let column: Column = col.clone().into();
+                match column.relation.as_ref() {
+                    Some(_) => {
+                        // qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)>
+                        vec![self.plan.schema().qualified_field_from_column(&column)]
+                    }
+                    None => {
+                        // qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)>
+                        self.plan
+                            .schema()
+                            .qualified_fields_with_unqualified_name(&column.name)
+                            .into_iter()
+                            .map(Ok)
+                            .collect::<Vec<_>>()
+                    }
+                }
             })
-            .collect::<Vec<_>>();
+            .collect::<Result<Vec<_>, _>>()?;
         let expr: Vec<Expr> = self
             .plan
             .schema()
@@ -470,7 +515,7 @@
     /// # #[tokio::main]
     /// # async fn main() -> Result<()> {
     /// let ctx = SessionContext::new();
-    /// let df = ctx.read_json("tests/data/unnest.json", NdJsonReadOptions::default()).await?;
+    /// let df = ctx.read_json("tests/data/unnest.json", JsonReadOptions::default()).await?;
     /// // expand into multiple columns if it's json array, flatten field name if it's nested structure
     /// let df = df.unnest_columns(&["b","c","d"])?;
     /// let expected = vec![
@@ -1474,6 +1519,11 @@
         let options = self.session_state.config().options().format.clone();
         let arrow_options: arrow::util::display::FormatOptions = (&options).try_into()?;
 
+        let registry = self.session_state.extension_type_registry();
+        let formatter_factory = DFArrayFormatterFactory::new(Arc::clone(registry));
+        let arrow_options =
+            arrow_options.with_formatter_factory(Some(&formatter_factory));
+
         let results = self.collect().await?;
         Ok(
             pretty::pretty_format_batches_with_options(&results, &arrow_options)?
@@ -1655,7 +1705,7 @@
     pub fn into_view(self) -> Arc<dyn TableProvider> {
         Arc::new(DataFrameTableProvider {
             plan: self.plan,
-            table_type: TableType::Temporary,
+            table_type: TableType::View,
         })
     }
 
@@ -2013,6 +2063,8 @@
 
         let file_type = format_as_file_type(format);
 
+        let copy_options = options.build_sink_options();
+
         let plan = if options.sort_by.is_empty() {
             self.plan
         } else {
@@ -2025,7 +2077,7 @@
             plan,
             path.into(),
             file_type,
-            HashMap::new(),
+            copy_options,
             options.partition_by,
         )?
         .build()?;
@@ -2081,6 +2133,8 @@
 
         let file_type = format_as_file_type(format);
 
+        let copy_options = options.build_sink_options();
+
         let plan = if options.sort_by.is_empty() {
             self.plan
         } else {
@@ -2093,7 +2147,7 @@
             plan,
             path.into(),
             file_type,
-            Default::default(),
+            copy_options,
             options.partition_by,
         )?
         .build()?;
@@ -2232,7 +2286,7 @@
             .schema()
             .iter()
             .map(|(qualifier, field)| {
-                if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename {
+                if qualifier.eq(&qualifier_rename) && field == field_rename {
                     (
                         col(Column::from((qualifier, field)))
                             .alias_qualified(qualifier.cloned(), new_name),
@@ -2321,6 +2375,10 @@
     /// Cache DataFrame as a memory table.
     ///
+    /// The default behavior can be changed using
+    /// a [`crate::execution::session_state::CacheFactory`]
+    /// configured via [`SessionState`].
+    ///
     /// ```
     /// # use datafusion::prelude::*;
     /// # use datafusion::error::Result;
     /// # #[tokio::main]
     /// # async fn main() -> Result<()> {
     /// # }
     /// ```
     pub async fn cache(self) -> Result<DataFrame> {
-        let context = SessionContext::new_with_state((*self.session_state).clone());
-        // The schema is consistent with the output
-        let plan = self.clone().create_physical_plan().await?;
-        let schema = plan.schema();
-        let task_ctx = Arc::new(self.task_ctx());
-        let partitions = collect_partitioned(plan, task_ctx).await?;
-        let mem_table = MemTable::try_new(schema, partitions)?;
-        context.read_table(Arc::new(mem_table))
+        if let Some(cache_factory) = self.session_state.cache_factory() {
+            let new_plan =
+                cache_factory.create(self.plan, self.session_state.as_ref())?;
+            Ok(Self::new(*self.session_state, new_plan))
+        } else {
+            let context = SessionContext::new_with_state((*self.session_state).clone());
+            // The schema is consistent with the output
+            let plan = self.create_physical_plan().await?;
+            let schema = plan.schema();
+            let task_ctx = Arc::new(self.task_ctx());
+            let partitions = collect_partitioned(plan, task_ctx).await?;
+            let mem_table = MemTable::try_new(schema, partitions)?;
+            context.read_table(Arc::new(mem_table))
+        }
     }
 
     /// Apply an alias to the DataFrame.
@@ -2383,6 +2447,7 @@
     /// # Ok(())
     /// # }
     /// ```
+    #[expect(clippy::needless_pass_by_value)]
     pub fn fill_null(
         &self,
         value: ScalarValue,
@@ -2393,7 +2458,7 @@
                 .schema()
                 .fields()
                 .iter()
-                .map(|f| f.as_ref().clone())
+                .map(Arc::clone)
                 .collect()
         } else {
             self.find_columns(&columns)?
@@ -2430,7 +2495,7 @@
     }
 
     // Helper to find columns from names
-    fn find_columns(&self, names: &[String]) -> Result<Vec<Field>> {
+    fn find_columns(&self, names: &[String]) -> Result<Vec<FieldRef>> {
         let schema = self.logical_plan().schema();
         names
             .iter()
@@ -2443,6 +2508,48 @@
             .collect()
     }
 
+    /// Find qualified columns for this dataframe from names
+    ///
+    /// # Arguments
+    /// * `names` - Unqualified names to find.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # use datafusion_common::ScalarValue;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new())
+    ///     .await?;
+    /// let df = ctx.table("first_table").await?;
+    /// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new())
+    ///     .await?;
+    /// let df2 = ctx.table("second_table").await?;
+    /// let join_expr = df.find_qualified_columns(&["a"])?.iter()
+    ///     .zip(df2.find_qualified_columns(&["a"])?.iter())
+    ///     .map(|(col1, col2)| col(*col1).eq(col(*col2)))
+    ///     .collect::<Vec<_>>();
+    /// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn find_qualified_columns(
+        &self,
+        names: &[&str],
+    ) -> Result<Vec<(Option<&TableReference>, &FieldRef)>> {
+        let schema = self.logical_plan().schema();
+        names
+            .iter()
+            .map(|name| {
+                schema
+                    .qualified_field_from_column(&Column::from_name(*name))
+                    .map_err(|_| plan_datafusion_err!("Column '{}' not found", name))
+            })
+            .collect()
+    }
+
+    /// Helper for creating DataFrame.
     /// # Example
     /// ```
@@ -2540,10 +2647,6 @@ struct DataFrameTableProvider {
 
 #[async_trait]
 impl TableProvider for DataFrameTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
         Some(Cow::Borrowed(&self.plan))
     }
diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs
index cb8a6cf29541b..83ffbb151773b 100644
--- a/datafusion/core/src/dataframe/parquet.rs
+++ b/datafusion/core/src/dataframe/parquet.rs
@@ -76,6 +76,8 @@ impl DataFrame {
 
         let file_type = format_as_file_type(format);
 
+        let copy_options = options.build_sink_options();
+
         let plan = if options.sort_by.is_empty() {
             self.plan
         } else {
@@ -88,7 +90,7 @@
             plan,
             path.into(),
             file_type,
-            Default::default(),
+            copy_options,
             options.partition_by,
         )?
         .build()?;
@@ -105,7 +107,6 @@
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
-    use std::sync::Arc;
 
     use super::super::Result;
     use super::*;
@@ -125,6 +126,19 @@
     use tempfile::TempDir;
     use url::Url;
 
+    /// Helper to extract a metric value by name from aggregated metrics.
+    fn metric_usize(
+        aggregated: &datafusion_physical_expr_common::metrics::MetricsSet,
+        name: &str,
+    ) -> usize {
+        aggregated
+            .iter()
+            .find(|m| m.value().name() == name)
+            .unwrap_or_else(|| panic!("should have {name} metric"))
+            .value()
+            .as_usize()
+    }
+
     #[tokio::test]
     async fn filter_pushdown_dataframe() -> Result<()> {
         let ctx = SessionContext::new();
@@ -150,7 +164,7 @@
         let plan = df.explain(false, false)?.collect().await?;
         // Filters all the way to Parquet
         let formatted = pretty::pretty_format_batches(&plan)?.to_string();
-        assert!(formatted.contains("FilterExec: id@0 = 1"));
+        assert!(formatted.contains("FilterExec: id@0 = 1"), "{formatted}");
 
         Ok(())
     }
@@ -298,8 +312,8 @@
 
         // Read encrypted parquet
         let ctx: SessionContext = SessionContext::new();
-        let read_options =
-            ParquetReadOptions::default().file_decryption_properties((&decrypt).into());
+        let read_options = ParquetReadOptions::default()
+            .file_decryption_properties((&decrypt).try_into()?);
 
         ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone())
             .await?;
@@ -324,4 +338,357 @@
 
         Ok(())
     }
+
+    /// Test FileOutputMode::SingleFile - explicitly request single file output
+    /// for paths WITHOUT file extensions. This verifies the fix for the regression
+    /// where extension heuristics ignored the explicit with_single_file_output(true).
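A usage sketch of the option under test (the output path here is illustrative): with `with_single_file_output(true)`, the exact path is written as one file even without an extension.

```rust
// Sketch only: forces one file at the exact path, overriding the
// extension heuristic that would otherwise create a directory.
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::error::Result;

async fn write_one_file(df: DataFrame) -> Result<()> {
    df.write_parquet(
        "/tmp/data_no_ext", // illustrative path; note: no .parquet suffix
        DataFrameWriteOptions::new().with_single_file_output(true),
        None,
    )
    .await?;
    Ok(())
}
```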
+ #[tokio::test] + async fn test_file_output_mode_single_file() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITHOUT .parquet extension - this is the key scenario + let output_path = tmp_dir.path().join("data_no_ext"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request single file output + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(true), + None, + ) + .await?; + + // Verify: output should be a FILE, not a directory + assert!( + output_path.is_file(), + "Expected single file at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the file is readable as parquet + let file = std::fs::File::open(&output_path)?; + let reader = parquet::file::reader::SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + assert_eq!(metadata.file_metadata().num_rows(), 3); + + Ok(()) + } + + /// Test FileOutputMode::Automatic - uses extension heuristic. + /// Path WITH extension -> single file; path WITHOUT extension -> directory. + #[tokio::test] + async fn test_file_output_mode_automatic() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + // Case 1: Path WITH extension -> should create single file (Automatic mode) + let output_with_ext = tmp_dir.path().join("data.parquet"); + let df = ctx.read_batch(batch.clone())?; + df.write_parquet( + output_with_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_with_ext.is_file(), + "Path with extension should be a single file, got is_file={}, is_dir={}", + output_with_ext.is_file(), + output_with_ext.is_dir() + ); + + // Case 2: Path WITHOUT extension -> should create directory (Automatic mode) + let output_no_ext = tmp_dir.path().join("data_dir"); + let df = ctx.read_batch(batch)?; + df.write_parquet( + output_no_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_no_ext.is_dir(), + "Path without extension should be a directory, got is_file={}, is_dir={}", + output_no_ext.is_file(), + output_no_ext.is_dir() + ); + + Ok(()) + } + + /// Test that ParquetSink exposes rows_written, bytes_written, and + /// elapsed_compute metrics via DataSinkExec. 
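The metrics tests that follow read these counters back off the executed plan; a small sketch of that lookup, mirroring the `metric_usize` helper defined earlier in this module:

```rust
// Sketch only: pulls an aggregated sink metric off an executed COPY plan,
// the same way metric_usize above does. Returns None if the plan
// exposes no metrics.
use std::sync::Arc;
use datafusion::physical_plan::ExecutionPlan;

fn rows_written(plan: &Arc<dyn ExecutionPlan>) -> Option<usize> {
    let aggregated = plan.metrics()?.aggregate_by_name();
    aggregated
        .iter()
        .find(|m| m.value().name() == "rows_written")
        .map(|m| m.value().as_usize())
}
```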
+    #[tokio::test]
+    async fn test_parquet_sink_metrics() -> Result<()> {
+        use arrow::array::Int32Array;
+        use arrow::datatypes::{DataType, Field, Schema};
+        use arrow::record_batch::RecordBatch;
+        use datafusion_execution::TaskContext;
+
+        use futures::TryStreamExt;
+
+        let ctx = SessionContext::new();
+        let tmp_dir = TempDir::new()?;
+        let output_path = tmp_dir.path().join("metrics_test.parquet");
+        let output_path_str = output_path.to_str().unwrap();
+
+        // Register a table with 100 rows
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("val", DataType::Int32, false),
+        ]));
+        let ids: Vec<i32> = (0..100).collect();
+        let vals: Vec<i32> = (100..200).collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(ids)),
+                Arc::new(Int32Array::from(vals)),
+            ],
+        )?;
+        ctx.register_batch("source", batch)?;
+
+        // Create the physical plan for COPY TO
+        let df = ctx
+            .sql(&format!(
+                "COPY source TO '{output_path_str}' STORED AS PARQUET"
+            ))
+            .await?;
+        let plan = df.create_physical_plan().await?;
+
+        // Execute the plan
+        let task_ctx = Arc::new(TaskContext::from(&ctx.state()));
+        let stream = plan.execute(0, task_ctx)?;
+        let _batches: Vec<_> = stream.try_collect().await?;
+
+        // Check metrics on the DataSinkExec (top-level plan)
+        let metrics = plan
+            .metrics()
+            .expect("DataSinkExec should return metrics from ParquetSink");
+        let aggregated = metrics.aggregate_by_name();
+
+        // rows_written should be 100
+        assert_eq!(
+            metric_usize(&aggregated, "rows_written"),
+            100,
+            "expected 100 rows written"
+        );
+
+        // bytes_written should be > 0
+        let bytes_written = metric_usize(&aggregated, "bytes_written");
+        assert!(
+            bytes_written > 0,
+            "expected bytes_written > 0, got {bytes_written}"
+        );
+
+        // elapsed_compute should be > 0
+        let elapsed = metric_usize(&aggregated, "elapsed_compute");
+        assert!(elapsed > 0, "expected elapsed_compute > 0");
+
+        Ok(())
+    }
+
+    /// Test that ParquetSink metrics work with single_file_parallelism enabled.
+    #[tokio::test]
+    async fn test_parquet_sink_metrics_parallel() -> Result<()> {
+        use arrow::array::Int32Array;
+        use arrow::datatypes::{DataType, Field, Schema};
+        use arrow::record_batch::RecordBatch;
+        use datafusion_execution::TaskContext;
+
+        use futures::TryStreamExt;
+
+        let ctx = SessionContext::new();
+        ctx.sql("SET datafusion.execution.parquet.allow_single_file_parallelism = true")
+            .await?
+            .collect()
+            .await?;
+
+        let tmp_dir = TempDir::new()?;
+        let output_path = tmp_dir.path().join("metrics_parallel.parquet");
+        let output_path_str = output_path.to_str().unwrap();
+
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+        let ids: Vec<i32> = (0..50).collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(ids))],
+        )?;
+        ctx.register_batch("source2", batch)?;
+
+        let df = ctx
+            .sql(&format!(
+                "COPY source2 TO '{output_path_str}' STORED AS PARQUET"
+            ))
+            .await?;
+        let plan = df.create_physical_plan().await?;
+        let task_ctx = Arc::new(TaskContext::from(&ctx.state()));
+        let stream = plan.execute(0, task_ctx)?;
+        let _batches: Vec<_> = stream.try_collect().await?;
+
+        let metrics = plan.metrics().expect("DataSinkExec should return metrics");
+        let aggregated = metrics.aggregate_by_name();
+
+        assert_eq!(metric_usize(&aggregated, "rows_written"), 50);
+        assert!(metric_usize(&aggregated, "bytes_written") > 0);
+        assert!(
+            metric_usize(&aggregated, "elapsed_compute") > 0,
+            "expected elapsed_compute > 0 on parallel path"
+        );
+
+        Ok(())
+    }
+
+    /// Test that ParquetSink reports a non-zero elapsed_compute on the sequential
+    /// write path (allow_single_file_parallelism = false), where elapsed_compute
+    /// is computed as total_write_time - io_time via TimingWriter.
+    #[tokio::test]
+    async fn test_parquet_sink_metrics_sequential() -> Result<()> {
+        use arrow::array::Int32Array;
+        use arrow::datatypes::{DataType, Field, Schema};
+        use arrow::record_batch::RecordBatch;
+        use datafusion_execution::TaskContext;
+
+        use futures::TryStreamExt;
+
+        let ctx = SessionContext::new();
+        ctx.sql("SET datafusion.execution.parquet.allow_single_file_parallelism = false")
+            .await?
+            .collect()
+            .await?;
+
+        let tmp_dir = TempDir::new()?;
+        let output_path = tmp_dir.path().join("metrics_sequential.parquet");
+        let output_path_str = output_path.to_str().unwrap();
+
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+        let ids: Vec<i32> = (0..50).collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(ids))],
+        )?;
+        ctx.register_batch("source_seq", batch)?;
+
+        let df = ctx
+            .sql(&format!(
+                "COPY source_seq TO '{output_path_str}' STORED AS PARQUET"
+            ))
+            .await?;
+        let plan = df.create_physical_plan().await?;
+        let task_ctx = Arc::new(TaskContext::from(&ctx.state()));
+        let stream = plan.execute(0, task_ctx)?;
+        let _batches: Vec<_> = stream.try_collect().await?;
+
+        let metrics = plan
+            .metrics()
+            .expect("DataSinkExec should return metrics from ParquetSink");
+        let aggregated = metrics.aggregate_by_name();
+
+        assert_eq!(metric_usize(&aggregated, "rows_written"), 50);
+        assert!(metric_usize(&aggregated, "bytes_written") > 0);
+        assert!(
+            metric_usize(&aggregated, "elapsed_compute") > 0,
+            "expected elapsed_compute > 0 on sequential path"
+        );
+
+        Ok(())
+    }
+
+    /// Test FileOutputMode::Directory - explicitly request directory output
+    /// even for paths WITH file extensions.
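The counterpart to the earlier sketch: `with_single_file_output(false)` forces a directory of generated part files even when the path looks like a file.

```rust
// Sketch only; path is illustrative. Despite the .parquet suffix,
// the output becomes a directory containing generated filenames.
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::error::Result;

async fn write_directory(df: DataFrame) -> Result<()> {
    df.write_parquet(
        "/tmp/output.parquet",
        DataFrameWriteOptions::new().with_single_file_output(false),
        None,
    )
    .await?;
    Ok(())
}
```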
+ #[tokio::test] + async fn test_file_output_mode_directory() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITH .parquet extension but explicitly requesting directory output + let output_path = tmp_dir.path().join("output.parquet"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request directory output (single_file_output = false) + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(false), + None, + ) + .await?; + + // Verify: output should be a DIRECTORY, not a single file + assert!( + output_path.is_dir(), + "Expected directory at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the directory contains parquet file(s) + let entries: Vec<_> = std::fs::read_dir(&output_path)? + .filter_map(|e| e.ok()) + .collect(); + assert!( + !entries.is_empty(), + "Directory should contain at least one file" + ); + + Ok(()) + } + + /// Test that `create_physical_plan` does not consume the `DataFrame`, so + /// callers can inspect (e.g. log) the physical plan and then still call + /// `write_parquet` or any other execution method on the same `DataFrame`. + #[tokio::test] + async fn create_physical_plan_does_not_consume_dataframe() -> Result<()> { + use crate::prelude::CsvReadOptions; + let ctx = SessionContext::new(); + let df = ctx + .read_csv("tests/data/example.csv", CsvReadOptions::new()) + .await?; + + // Obtain the physical plan for inspection without consuming `df`. + let _physical_plan = df.create_physical_plan().await?; + + // `df` is still usable — collect the results. + let batches = df.collect().await?; + assert!(!batches.is_empty()); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs index 256a11ba693b5..50ee96da3dff0 100644 --- a/datafusion/core/src/datasource/dynamic_file.rs +++ b/datafusion/core/src/datasource/dynamic_file.rs @@ -20,9 +20,9 @@ use std::sync::Arc; +use crate::datasource::TableProvider; use crate::datasource::listing::ListingTableConfigExt; use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; -use crate::datasource::TableProvider; use crate::error::Result; use crate::execution::context::SessionState; diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 8701f96eb3b84..338de76b1353b 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -17,3 +17,96 @@ //! Re-exports the [`datafusion_datasource_arrow::file_format`] module, and contains tests for it. 
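The two tests below drive the empty-result case through SQL; a minimal sketch of the COPY they exercise, which should still produce a valid Arrow IPC file for a zero-row query:

```rust
// Sketch only: COPY of a zero-row SELECT should still create the file at
// `path`, preserving the schema.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn copy_empty_arrow(ctx: &SessionContext, path: &str) -> Result<()> {
    ctx.sql(&format!(
        "COPY (SELECT CAST(1 AS BIGINT) AS id LIMIT 0) TO '{path}' STORED AS ARROW"
    ))
    .await?
    .collect()
    .await?;
    Ok(())
}
```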
pub use datafusion_datasource_arrow::file_format::*;
+
+#[cfg(test)]
+mod tests {
+    use futures::StreamExt;
+    use std::sync::Arc;
+
+    use arrow::array::{Int64Array, StringArray};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use datafusion_common::Result;
+
+    use crate::execution::options::ArrowReadOptions;
+    use crate::prelude::SessionContext;
+
+    #[tokio::test]
+    async fn test_write_empty_arrow_from_sql() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        let tmp_dir = tempfile::TempDir::new()?;
+        let path = format!("{}/empty_sql.arrow", tmp_dir.path().to_string_lossy());
+
+        ctx.sql(&format!(
+            "COPY (SELECT CAST(1 AS BIGINT) AS id LIMIT 0) TO '{path}' STORED AS ARROW",
+        ))
+        .await?
+        .collect()
+        .await?;
+
+        assert!(std::path::Path::new(&path).exists());
+
+        let read_df = ctx.read_arrow(&path, ArrowReadOptions::default()).await?;
+        let stream = read_df.execute_stream().await?;
+
+        assert_eq!(stream.schema().fields().len(), 1);
+        assert_eq!(stream.schema().field(0).name(), "id");
+
+        let results: Vec<_> = stream.collect().await;
+        let total_rows: usize = results
+            .iter()
+            .filter_map(|r| r.as_ref().ok())
+            .map(|b| b.num_rows())
+            .sum();
+        assert_eq!(total_rows, 0);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_write_empty_arrow_from_record_batch() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int64, false),
+            Field::new("name", DataType::Utf8, true),
+        ]));
+        let empty_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int64Array::from(Vec::<i64>::new())),
+                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
+            ],
+        )?;
+
+        let tmp_dir = tempfile::TempDir::new()?;
+        let path = format!("{}/empty_batch.arrow", tmp_dir.path().to_string_lossy());
+
+        ctx.register_batch("empty_table", empty_batch)?;
+
+        ctx.sql(&format!("COPY empty_table TO '{path}' STORED AS ARROW"))
+            .await?
+ .collect() + .await?; + + assert!(std::path::Path::new(&path).exists()); + + let read_df = ctx.read_arrow(&path, ArrowReadOptions::default()).await?; + let stream = read_df.execute_stream().await?; + + assert_eq!(stream.schema().fields().len(), 2); + assert_eq!(stream.schema().field(0).name(), "id"); + assert_eq!(stream.schema().field(1).name(), "name"); + + let results: Vec<_> = stream.collect().await; + let total_rows: usize = results + .iter() + .filter_map(|r| r.as_ref().ok()) + .map(|b| b.num_rows()) + .sum(); + assert_eq!(total_rows, 0); + + Ok(()) + } +} diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 3428d08a6ae52..a8b48cc736c92 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -26,20 +26,21 @@ mod tests { use crate::{ datasource::file_format::test_util::scan_format, prelude::SessionContext, }; - use arrow::array::{as_string_array, Array}; + use arrow::array::{Array, as_string_array}; use datafusion_catalog::Session; use datafusion_common::test_util::batches_to_string; use datafusion_common::{ + Result, cast::{ as_binary_array, as_boolean_array, as_float32_array, as_float64_array, as_int32_array, as_timestamp_microsecond_array, }, - test_util, Result, + test_util, }; use datafusion_datasource_avro::AvroFormat; use datafusion_execution::config::SessionConfig; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use futures::StreamExt; use insta::assert_snapshot; @@ -94,7 +95,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); assert_eq!( vec![ @@ -108,7 +109,7 @@ mod tests { "double_col: Float64", "date_string_col: Binary", "string_col: Binary", - "timestamp_col: Timestamp(Microsecond, None)", + "timestamp_col: Timestamp(µs, \"+00:00\")", ], x ); @@ -116,20 +117,20 @@ mod tests { let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); - assert_snapshot!(batches_to_string(&batches),@r###" - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - | 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 | - | 5 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30332f30312f3039 | 31 | 2009-03-01T00:01:00 | - | 6 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30342f30312f3039 | 30 | 2009-04-01T00:00:00 | - | 7 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30342f30312f3039 | 31 | 2009-04-01T00:01:00 | - | 2 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30322f30312f3039 | 30 | 2009-02-01T00:00:00 | - | 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 2009-02-01T00:01:00 | - | 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00 | - | 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00 | - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - "###); + 
assert_snapshot!(batches_to_string(&batches),@r" + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+----------------------+ + | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+----------------------+ + | 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00Z | + | 5 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30332f30312f3039 | 31 | 2009-03-01T00:01:00Z | + | 6 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30342f30312f3039 | 30 | 2009-04-01T00:00:00Z | + | 7 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30342f30312f3039 | 31 | 2009-04-01T00:01:00Z | + | 2 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30322f30312f3039 | 30 | 2009-02-01T00:00:00Z | + | 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 2009-02-01T00:01:00Z | + | 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00Z | + | 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00Z | + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+----------------------+ + "); Ok(()) } @@ -245,7 +246,10 @@ mod tests { values.push(array.value(i)); } - assert_eq!("[1235865600000000, 1235865660000000, 1238544000000000, 1238544060000000, 1233446400000000, 1233446460000000, 1230768000000000, 1230768060000000]", format!("{values:?}")); + assert_eq!( + "[1235865600000000, 1235865660000000, 1238544000000000, 1238544060000000, 1233446400000000, 1233446460000000, 1230768000000000, 1230768060000000]", + format!("{values:?}") + ); Ok(()) } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 52fb8ae904ebf..a068b4f5c0413 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -32,12 +32,12 @@ mod tests { use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_catalog::Session; + use datafusion_common::Result; use datafusion_common::cast::as_string_array; use datafusion_common::config::CsvOptions; use datafusion_common::internal_err; use datafusion_common::stats::Precision; use datafusion_common::test_util::{arrow_test_data, batches_to_string}; - use datafusion_common::Result; use datafusion_datasource::decoder::{ BatchDeserializer, DecoderDeserializer, DeserializerOutput, }; @@ -45,7 +45,7 @@ mod tests { use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::write::BatchSerializer; use datafusion_expr::{col, lit}; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use arrow::array::{ Array, BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray, @@ -57,15 +57,16 @@ mod tests { use bytes::Bytes; use chrono::DateTime; use datafusion_common::parsers::CompressionTypeVariant; - use futures::stream::BoxStream; use futures::StreamExt; + use futures::stream::BoxStream; use insta::assert_snapshot; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, - 
ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, + PutPayload, PutResult, }; use regex::Regex; use rstest::*; @@ -104,10 +105,6 @@ mod tests { unimplemented!() } - async fn get(&self, location: &Path) -> object_store::Result { - self.get_opts(location, GetOptions::default()).await - } - async fn get_opts( &self, location: &Path, @@ -117,6 +114,8 @@ mod tests { let len = bytes.len() as u64; let range = 0..len * self.max_iterations; let arc = self.iterations_detected.clone(); + #[expect(clippy::result_large_err)] + // closure only ever returns Ok; Err type is never constructed let stream = futures::stream::repeat_with(move || { let arc_inner = arc.clone(); *arc_inner.lock().unwrap() += 1; @@ -147,14 +146,6 @@ mod tests { unimplemented!() } - async fn head(&self, _location: &Path) -> object_store::Result { - unimplemented!() - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - unimplemented!() - } - fn list( &self, _prefix: Option<&Path>, @@ -169,17 +160,21 @@ mod tests { unimplemented!() } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - unimplemented!() - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: object_store::CopyOptions, ) -> object_store::Result<()> { unimplemented!() } + + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() + } } impl VariableStream { @@ -621,15 +616,15 @@ mod tests { .collect() .await?; - assert_snapshot!(batches_to_string(&record_batch), @r###" - +----+------+ - | c2 | c3 | - +----+------+ - | 5 | 36 | - | 5 | -31 | - | 5 | -101 | - +----+------+ - "###); + assert_snapshot!(batches_to_string(&record_batch), @r" + +----+------+ + | c2 | c3 | + +----+------+ + | 5 | 36 | + | 5 | -31 | + | 5 | -101 | + +----+------+ + "); Ok(()) } @@ -706,11 +701,11 @@ mod tests { let re = Regex::new(r"DataSourceExec: file_groups=\{(\d+) group").unwrap(); - if let Some(captures) = re.captures(&plan) { - if let Some(match_) = captures.get(1) { - let n_partitions = match_.as_str().parse::().unwrap(); - return Ok(n_partitions); - } + if let Some(captures) = re.captures(&plan) + && let Some(match_) = captures.get(1) + { + let n_partitions = match_.as_str().parse::().unwrap(); + return Ok(n_partitions); } internal_err!("query contains no DataSourceExec") @@ -736,13 +731,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c2) | +--------------+ | 285 | +--------------+ - "###); + "); } assert_eq!(n_partitions, actual_partitions); @@ -775,13 +770,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c3) | +--------------+ | 781 | +--------------+ - "###); + "); } assert_eq!(1, actual_partitions); // Compressed csv won't be scanned in parallel @@ -812,13 +807,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c3) | +--------------+ | 781 | +--------------+ - "###); + "); } assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set @@ -843,10 +838,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -868,10 +863,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -944,17 +939,19 @@ mod tests { let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); assert_eq!(files.len(), 1); - assert!(files - .last() - .unwrap() - .as_ref() - .unwrap() - .path() - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".csv.gz")); + assert!( + files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv.gz") + ); Ok(()) } @@ -983,17 +980,19 @@ mod tests { let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); assert_eq!(files.len(), 1); - assert!(files - .last() - .unwrap() - .as_ref() - .unwrap() - .path() - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".csv")); + assert!( + files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv") + ); Ok(()) } @@ -1032,10 +1031,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -1084,13 +1083,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" - +---------------------+ - | sum(empty.column_1) | - +---------------------+ - | 10 | - +---------------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" + +---------------------+ + | sum(empty.column_1) | + +---------------------+ + | 10 | + +---------------------+ + ");} assert_eq!(n_partitions, actual_partitions); // Won't get partitioned if all files are empty @@ -1132,13 +1131,13 @@ mod tests { file_size }; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&query_result),@r" +-----------------------+ | sum(one_col.column_1) | +-----------------------+ | 50 | +-----------------------+ - "###); + "); } assert_eq!(expected_partitions, actual_partitions); @@ -1171,13 +1170,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" - +---------------+ - | sum_of_5_cols | - +---------------+ - | 15 | - +---------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" + +---------------+ + | sum_of_5_cols | + +---------------+ + | 15 | + +---------------+ + ");} assert_eq!(n_partitions, actual_partitions); @@ -1191,7 +1190,9 @@ mod tests { ) -> Result<()> { let schema = csv_schema(); let generator = CsvBatchGenerator::new(batch_size, line_count); - let mut deserializer = csv_deserializer(batch_size, &schema); + + let schema_clone = Arc::clone(&schema); + let mut deserializer = csv_deserializer(batch_size, &schema_clone); for data in generator { deserializer.digest(data); @@ -1230,7 +1231,8 @@ mod tests { ) -> Result<()> { let schema = csv_schema(); let generator = CsvBatchGenerator::new(batch_size, line_count); - let mut deserializer = csv_deserializer(batch_size, &schema); + let schema_clone = Arc::clone(&schema); + let mut deserializer = csv_deserializer(batch_size, &schema_clone); for data in generator { deserializer.digest(data); @@ -1499,7 +1501,7 @@ mod tests { // Create a temp file with a .csv suffix so the reader accepts it let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?; // ensures path ends with .csv - // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete. + // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete. 
write!(tmp, "a,b,c\n1,2\n3,4,5\n")?; let path = tmp.path().to_str().unwrap().to_string(); @@ -1529,4 +1531,94 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_write_empty_csv_from_sql() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_sql.csv", tmp_dir.path().to_string_lossy()); + let df = ctx.sql("SELECT CAST(1 AS BIGINT) AS id LIMIT 0").await?; + df.write_csv(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + + let read_df = ctx + .read_csv(&path, CsvReadOptions::default().has_header(true)) + .await?; + let stream = read_df.execute_stream().await?; + assert_eq!(stream.schema().fields().len(), 1); + assert_eq!(stream.schema().field(0).name(), "id"); + + let results: Vec<_> = stream.collect().await; + assert_eq!(results.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_write_empty_csv_from_record_batch() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + let empty_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::Int64Array::from(Vec::<i64>::new())), + Arc::new(StringArray::from(Vec::<Option<&str>>::new())), + ], + )?; + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_batch.csv", tmp_dir.path().to_string_lossy()); + + // Write empty RecordBatch + let df = ctx.read_batch(empty_batch.clone())?; + df.write_csv(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + // Expected the file to exist + assert!(std::path::Path::new(&path).exists()); + + let read_df = ctx + .read_csv(&path, CsvReadOptions::default().has_header(true)) + .await?; + let stream = read_df.execute_stream().await?; + assert_eq!(stream.schema().fields().len(), 2); + assert_eq!(stream.schema().field(0).name(), "id"); + assert_eq!(stream.schema().field(1).name(), "name"); + + let results: Vec<_> = stream.collect().await; + assert_eq!(results.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_infer_schema_with_zero_max_records() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let root = format!("{}/csv", arrow_test_data()); + let format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(0); // Set to 0 to disable inference + let exec = scan_format( + &state, + &format, + None, + &root, + "aggregate_test_100.csv", + None, + None, + ) + .await?; + + // related to https://github.com/apache/datafusion/issues/19417 + for f in exec.schema().fields() { + assert_eq!(*f.data_type(), DataType::Utf8); + } + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 34d3d64f07fb2..5b3e22705620e 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -25,7 +25,7 @@ mod tests { use super::*; use crate::datasource::file_format::test_util::scan_format; - use crate::prelude::{NdJsonReadOptions, SessionConfig, SessionContext}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::RecordBatch; use arrow_schema::Schema; @@ -36,7 +36,7 @@ mod tests { BatchDeserializer, DecoderDeserializer, DeserializerOutput, }; use datafusion_datasource::file_format::FileFormat; - use 
datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use arrow::compute::concat_batches; use arrow::datatypes::{DataType, Field}; @@ -46,12 +46,54 @@ mod tests { use datafusion_common::internal_err; use datafusion_common::stats::Precision; + use crate::execution::options::JsonReadOptions; use datafusion_common::Result; + use datafusion_datasource::file_compression_type::FileCompressionType; use futures::StreamExt; use insta::assert_snapshot; use object_store::local::LocalFileSystem; use regex::Regex; use rstest::rstest; + // ==================== Test Helpers ==================== + + /// Create a temporary JSON file and return (TempDir, path) + fn create_temp_json(content: &str) -> (tempfile::TempDir, String) { + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = tmp_dir.path().join("test.json"); + std::fs::write(&path, content).unwrap(); + (tmp_dir, path.to_string_lossy().to_string()) + } + + /// Infer schema from JSON array format file + async fn infer_json_array_schema( + content: &str, + ) -> Result<SchemaRef> { + let (_tmp_dir, path) = create_temp_json(content); + let session = SessionContext::new(); + let ctx = session.state(); + let store = Arc::new(LocalFileSystem::new()) as _; + let format = JsonFormat::default().with_newline_delimited(false); + format + .infer_schema(&ctx, &store, &[local_unpartitioned_file(&path)]) + .await + } + + /// Register a JSON array table and run a query + async fn query_json_array(content: &str, query: &str) -> Result<Vec<RecordBatch>> { + let (_tmp_dir, path) = create_temp_json(content); + let ctx = SessionContext::new(); + let options = JsonReadOptions::default().newline_delimited(false); + ctx.register_json("test_table", &path, options).await?; + ctx.sql(query).await?.collect().await + } + + /// Register a JSON array table and run a query, return formatted string + async fn query_json_array_str(content: &str, query: &str) -> Result<String> { + let result = query_json_array(content, query).await?; + Ok(batches_to_string(&result)) + } + + // ==================== Existing Tests ==================== #[tokio::test] async fn read_small_batches() -> Result<()> { @@ -187,11 +229,11 @@ mod tests { let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); - if let Some(captures) = re.captures(&plan) { - if let Some(match_) = captures.get(1) { - let count = match_.as_str().parse::<usize>().unwrap(); - return Ok(count); - } + if let Some(captures) = re.captures(&plan) + && let Some(match_) = captures.get(1) + { + let count = match_.as_str().parse::<usize>().unwrap(); + return Ok(count); } internal_err!("Query contains no Exec: file_groups") @@ -208,7 +250,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/1.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel", table_path, options) .await?; @@ -218,13 +260,13 @@ mod tests { let result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_num_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&result),@r###" - +----------------------+ - | sum(json_parallel.a) | - +----------------------+ - | -7 | - +----------------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&result),@r" + +----------------------+ + | sum(json_parallel.a) | + +----------------------+ + | -7 | + +----------------------+ + ");} assert_eq!(n_partitions, actual_partitions); @@ -240,7 +282,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/empty.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel_empty", table_path, options) .await?; @@ -249,10 +291,10 @@ mod tests { let result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result),@r" + ++ + ++ + "); Ok(()) } @@ -284,15 +326,15 @@ mod tests { } assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); - assert_snapshot!(batches_to_string(&[all_batches]),@r###" - +----+----+----+----+----+ - | c1 | c2 | c3 | c4 | c5 | - +----+----+----+----+----+ - | 1 | 2 | 3 | 4 | 5 | - | 6 | 7 | 8 | 9 | 10 | - | 11 | 12 | 13 | 14 | 15 | - +----+----+----+----+----+ - "###); + assert_snapshot!(batches_to_string(&[all_batches]),@r" + +----+----+----+----+----+ + | c1 | c2 | c3 | c4 | c5 | + +----+----+----+----+----+ + | 1 | 2 | 3 | 4 | 5 | + | 6 | 7 | 8 | 9 | 10 | + | 11 | 12 | 13 | 14 | 15 | + +----+----+----+----+----+ + "); Ok(()) } @@ -314,7 +356,6 @@ mod tests { .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); let mut all_batches = RecordBatch::new_empty(schema.clone()); - // We get RequiresMoreData after 2 batches because of how json::Decoder works for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { @@ -324,14 +365,14 @@ mod tests { } assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); - insta::assert_snapshot!(fmt_batches(&[all_batches]),@r###" - +----+----+----+----+----+ - | c1 | c2 | c3 | c4 | c5 | - +----+----+----+----+----+ - | 1 | 2 | 3 | 4 | 5 | - | 6 | 7 | 8 | 9 | 10 | - +----+----+----+----+----+ - "###); + insta::assert_snapshot!(fmt_batches(&[all_batches]),@r" + +----+----+----+----+----+ + | c1 | c2 | c3 | c4 | c5 | + +----+----+----+----+----+ + | 1 | 2 | 3 | 4 | 5 | + | 6 | 7 | 8 | 9 | 10 | + +----+----+----+----+----+ + "); Ok(()) } @@ -349,4 +390,248 @@ mod tests { fn fmt_batches(batches: &[RecordBatch]) -> String { pretty::pretty_format_batches(batches).unwrap().to_string() } + + #[tokio::test] + async fn test_write_empty_json_from_sql() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("empty_sql.json"); + let path = path.to_string_lossy().to_string(); + let df = ctx.sql("SELECT CAST(1 AS BIGINT) AS id LIMIT 0").await?; + df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + let metadata = std::fs::metadata(&path)?; + assert_eq!(metadata.len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_write_empty_json_from_record_batch() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + let empty_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::Int64Array::from(Vec::::new())), + Arc::new(arrow::array::StringArray::from(Vec::>::new())), + ], + )?; + + let tmp_dir = tempfile::TempDir::new()?; + let path = 
tmp_dir.path().join("empty_batch.json"); + let path = path.to_string_lossy().to_string(); + let df = ctx.read_batch(empty_batch.clone())?; + df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + let metadata = std::fs::metadata(&path)?; + assert_eq!(metadata.len(), 0); + Ok(()) + } + + // ==================== JSON Array Format Tests ==================== + + #[tokio::test] + async fn test_json_array_schema_inference() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"a": 1, "b": 2.0, "c": true}, {"a": 2, "b": 3.5, "c": false}]"#, + ) + .await?; + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_empty() -> Result<()> { + let schema = infer_json_array_schema("[]").await?; + assert_eq!(schema.fields().len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"id": 1, "info": {"name": "Alice", "age": 30}}]"#, + ) + .await?; + + let info_field = schema.field_with_name("info").unwrap(); + assert!(matches!(info_field.data_type(), DataType::Struct(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_type() -> Result<()> { + let schema = + infer_json_array_schema(r#"[{"id": 1, "tags": ["a", "b", "c"]}]"#).await?; + + let tags_field = schema.field_with_name("tags").unwrap(); + assert!(matches!(tags_field.data_type(), DataType::List(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_basic_query() -> Result<()> { + let result = query_json_array_str( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}, {"a": 3, "b": "test"}]"#, + "SELECT a, b FROM test_table ORDER BY a", + ) + .await?; + + assert_snapshot!(result, @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + | 3 | test | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_nulls() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "name": "Alice"}, {"id": 2, "name": null}, {"id": 3, "name": "Charlie"}]"#, + "SELECT id, name FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+ + | id | name | + +----+---------+ + | 1 | Alice | + | 2 | | + | 3 | Charlie | + +----+---------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "values": [10, 20, 30]}, {"id": 2, "values": [40, 50]}]"#, + "SELECT id, unnest(values) as value FROM test_table ORDER BY id, value", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------+ + | id | value | + +----+-------+ + | 1 | 10 | + | 1 | 20 | + | 1 | 30 | + | 2 | 40 | + | 2 | 50 | + +----+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest_struct() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "orders": [{"product": "A", "qty": 2}, {"product": "B", "qty": 3}]}, {"id": 2, "orders": [{"product": "C", "qty": 1}]}]"#, + "SELECT id, unnest(orders)['product'] as product, unnest(orders)['qty'] as qty FROM test_table ORDER BY id, product", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+-----+ + | id | product | qty | + +----+---------+-----+ + | 1 | A | 2 | + | 1 | B | 3 | + | 2 | C | 1 | + +----+---------+-----+ + "); + 
Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct_access() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "dept": {"name": "Engineering", "head": "Alice"}}, {"id": 2, "dept": {"name": "Sales", "head": "Bob"}}]"#, + "SELECT id, dept['name'] as dept_name, dept['head'] as head FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------------+-------+ + | id | dept_name | head | + +----+-------------+-------+ + | 1 | Engineering | Alice | + | 2 | Sales | Bob | + +----+-------------+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_compression() -> Result<()> { + use flate2::Compression; + use flate2::write::GzEncoder; + use std::io::Write; + + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("array.json.gz"); + let path = path.to_string_lossy().to_string(); + + let file = std::fs::File::create(&path)?; + let mut encoder = GzEncoder::new(file, Compression::default()); + encoder.write_all( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}]"#.as_bytes(), + )?; + encoder.finish()?; + + let ctx = SessionContext::new(); + let options = JsonReadOptions::default() + .newline_delimited(false) + .file_compression_type(FileCompressionType::GZIP) + .file_extension(".json.gz"); + + ctx.register_json("test_table", &path, options).await?; + let result = ctx + .sql("SELECT a, b FROM test_table ORDER BY a") + .await? + .collect() + .await?; + + assert_snapshot!(batches_to_string(&result), @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_of_structs() -> Result<()> { + let batches = query_json_array( + r#"[{"id": 1, "items": [{"name": "x", "price": 10.5}]}, {"id": 2, "items": []}]"#, + "SELECT id, items FROM test_table ORDER BY id", + ) + .await?; + + assert_eq!(1, batches.len()); + assert_eq!(2, batches[0].num_rows()); + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 4881783eeba69..b04238ebc9b37 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -39,8 +39,9 @@ pub(crate) mod test_util { use arrow_schema::SchemaRef; use datafusion_catalog::Session; use datafusion_common::Result; + use datafusion_datasource::TableSchema; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::{file_format::FileFormat, PartitionedFile}; + use datafusion_datasource::{PartitionedFile, file_format::FileFormat}; use datafusion_execution::object_store::ObjectStoreUrl; use std::sync::Arc; @@ -66,31 +67,24 @@ pub(crate) mod test_util { .await? }; + let table_schema = TableSchema::new(file_schema.clone(), vec![]); + let statistics = format .infer_stats(state, &store, file_schema.clone(), &meta) .await?; - let file_groups = vec![vec![PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }] - .into()]; + let file_groups = vec![vec![PartitionedFile::new_from_meta(meta)].into()]; let exec = format .create_physical_plan( state, FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - format.file_source(), + format.file_source(table_schema), ) .with_file_groups(file_groups) .with_statistics(statistics) - .with_projection_indices(projection) + .with_projection_indices(projection)? 
.with_limit(limit) .build(), ) .await?; @@ -131,7 +125,10 @@ mod tests { .write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'" + ); Ok(()) } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index e78c5f09553cc..bd0ac36087381 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -25,9 +25,9 @@ use crate::datasource::file_format::avro::AvroFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; +use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::listing::ListingTableUrl; use crate::datasource::{file_format::csv::CsvFormat, listing::ListingOptions}; use crate::error::Result; @@ -442,14 +442,23 @@ impl<'a> AvroReadOptions<'a> { } } -/// Options that control the reading of Line-delimited JSON files (NDJson) +#[deprecated( + since = "53.0.0", + note = "Use `JsonReadOptions` instead. This alias will be removed in a future version." +)] +#[doc = "Deprecated: Use [`JsonReadOptions`] instead."] +pub type NdJsonReadOptions<'a> = JsonReadOptions<'a>; + +/// Options that control the reading of JSON files. +/// +/// Supports both newline-delimited JSON (NDJSON) and JSON array formats. /// /// Note this structure is supplied when a datasource is created and -/// can not not vary from statement to statement. For settings that +/// can not vary from statement to statement. For settings that /// can vary statement to statement see /// [`ConfigOptions`](crate::config::ConfigOptions). #[derive(Clone)] -pub struct NdJsonReadOptions<'a> { +pub struct JsonReadOptions<'a> { /// The data source schema. pub schema: Option<&'a Schema>, /// Max number of rows to read from JSON files for schema inference if needed. Defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD`. @@ -465,9 +474,25 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec<Vec<SortExpr>>, /// Whether to read as newline-delimited JSON (default: true). 
+ /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, } -impl Default for NdJsonReadOptions<'_> { +impl Default for JsonReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -477,11 +502,12 @@ impl Default for NdJsonReadOptions<'_> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], + newline_delimited: true, } } } -impl<'a> NdJsonReadOptions<'a> { +impl<'a> JsonReadOptions<'a> { /// Specify table_partition_cols for partition pruning pub fn table_partition_cols( mut self, @@ -523,6 +549,32 @@ impl<'a> NdJsonReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Specify how many rows to read for schema inference + pub fn schema_infer_max_records(mut self, schema_infer_max_records: usize) -> Self { + self.schema_infer_max_records = schema_infer_max_records; + self + } + + /// Set whether to read as newline-delimited JSON. + /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub fn newline_delimited(mut self, newline_delimited: bool) -> Self { + self.newline_delimited = newline_delimited; + self + } } #[async_trait] @@ -648,7 +700,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { } #[async_trait] -impl ReadOptions<'_> for NdJsonReadOptions<'_> { +impl ReadOptions<'_> for JsonReadOptions<'_> { fn to_listing_options( &self, config: &SessionConfig, @@ -657,7 +709,8 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { let file_format = JsonFormat::default() .with_options(table_options.json) .with_schema_infer_max_rec(self.schema_infer_max_records) - .with_file_compression_type(self.file_compression_type.to_owned()); + .with_file_compression_type(self.file_compression_type.to_owned()) + .with_newline_delimited(self.newline_delimited); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension)
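A minimal usage sketch of the `JsonReadOptions` API introduced above. The file paths and table names here are hypothetical, and the `datafusion::execution::options::JsonReadOptions` import path is assumed from the `crate::execution::options` re-export used elsewhere in this diff; `register_json` and the `newline_delimited`/`schema_infer_max_records` builders are the APIs shown in the hunks above:

```rust
use datafusion::error::Result;
use datafusion::execution::options::JsonReadOptions;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Default behavior: newline-delimited JSON (NDJSON), one object per line.
    ctx.register_json("events_ndjson", "events.ndjson", JsonReadOptions::default())
        .await?;

    // JSON array format: the file is a single top-level `[ ... ]` of objects.
    // `newline_delimited(false)` switches the reader into this mode.
    let array_options = JsonReadOptions::default()
        .newline_delimited(false)
        .schema_infer_max_records(100);
    ctx.register_json("events_array", "events.json", array_options)
        .await?;

    ctx.sql("SELECT * FROM events_array").await?.show().await?;
    Ok(())
}
```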
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 52c5393e10319..6a8f7ab999757 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -107,8 +107,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -120,6 +120,7 @@ mod tests { use arrow::array::RecordBatch; use arrow_schema::Schema; use datafusion_catalog::Session; + use datafusion_common::ScalarValue::Utf8; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, as_float64_array, as_int32_array, as_timestamp_nanosecond_array, @@ -127,43 +128,45 @@ mod tests { use datafusion_common::config::{ParquetOptions, TableParquetOptions}; use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_string; - use datafusion_common::ScalarValue::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; + use datafusion_datasource::file_sink_config::{ + FileOutputMode, FileSink, FileSinkConfig, + }; use datafusion_datasource::{ListingTableUrl, PartitionedFile}; use datafusion_datasource_parquet::{ ParquetFormat, ParquetFormatFactory, ParquetSink, }; + use datafusion_execution::TaskContext; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use crate::test_util::bounded_stream; use arrow::array::{ - types::Int32Type, Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, - StringArray, + Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, StringArray, + types::Int32Type, }; use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource_parquet::metadata::DFParquetMetadata; - use futures::stream::BoxStream; use futures::StreamExt; + use futures::stream::BoxStream; use insta::assert_snapshot; use object_store::local::LocalFileSystem; - use object_store::ObjectMeta; + use object_store::{CopyOptions, ObjectMeta}; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, - PutMultipartOptions, PutOptions, PutPayload, PutResult, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, + PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, }; - use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; + use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::file::metadata::{ - KeyValue, ParquetColumnIndex, ParquetMetaData, ParquetOffsetIndex, + KeyValue, PageIndexPolicy, ParquetColumnIndex, ParquetMetaData, + ParquetOffsetIndex, }; use parquet::file::page_index::column_index::ColumnIndexMetaData; use tokio::fs::File; @@ -308,7 +311,7 @@ mod tests { _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result<PutResult> { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn put_multipart_opts( @@ -316,7 +319,7 @@ mod tests { _location: &Path, _opts: PutMultipartOptions, ) -> object_store::Result<Box<dyn MultipartUpload>> { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn get_opts( @@ -328,40 +331,34 @@ mod tests { self.inner.get_opts(location, options).await } - async fn head(&self, _location: &Path) -> object_store::Result<ObjectMeta> { - Err(object_store::Error::NotImplemented) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result<Path>>, + ) -> BoxStream<'static, object_store::Result<Path>> { + unimplemented!() } fn list( &self, _prefix: Option<&Path>, ) -> BoxStream<'static, object_store::Result<ObjectMeta>> { - Box::pin(futures::stream::once(async { - Err(object_store::Error::NotImplemented) - })) + unimplemented!() } async fn list_with_delimiter( &self, _prefix: Option<&Path>, ) -> object_store::Result<ListResult> { - Err(object_store::Error::NotImplemented) - } - - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - 
Err(object_store::Error::NotImplemented) + unimplemented!() } - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + unimplemented!() } } @@ -724,7 +721,7 @@ mod tests { // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -770,10 +767,9 @@ mod tests { exec.partition_statistics(None)?.num_rows, Precision::Exact(8) ); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); @@ -816,7 +812,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); let y = x.join("\n"); assert_eq!(expected, y); @@ -842,7 +838,7 @@ mod tests { double_col: Float64\n\ date_string_col: Binary\n\ string_col: Binary\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::No, no_views).await?; let with_views = "id: Int32\n\ @@ -855,7 +851,7 @@ mod tests { double_col: Float64\n\ date_string_col: BinaryView\n\ string_col: BinaryView\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::Yes, with_views).await?; Ok(()) @@ -931,7 +927,10 @@ mod tests { values.push(array.value(i)); } - assert_eq!("[1235865600000000000, 1235865660000000000, 1238544000000000000, 1238544060000000000, 1233446400000000000, 1233446460000000000, 1230768000000000000, 1230768060000000000]", format!("{values:?}")); + assert_eq!( + "[1235865600000000000, 1235865660000000000, 1238544000000000000, 1238544060000000000, 1233446400000000000, 1233446460000000000, 1230768000000000000, 1230768060000000000]", + format!("{values:?}") + ); Ok(()) } @@ -1101,7 +1100,8 @@ mod tests { let testdata = datafusion_common::test_util::parquet_test_data(); let path = format!("{testdata}/alltypes_tiny_pages.parquet"); let file = File::open(path).await?; - let options = ArrowReaderOptions::new().with_page_index(true); + let options = + ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required); let builder = ParquetRecordBatchStreamBuilder::new_with_options(file, options.clone()) .await? 
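A minimal sketch of the policy-based page-index API used in the hunk above, relying on the `parquet` crate items imported in this diff (`ArrowReaderOptions`, `PageIndexPolicy`, `ParquetRecordBatchStreamBuilder`). The file path is hypothetical, and the behavioral note is an assumption: `Required` is expected to surface an error when a file lacks page indexes, unlike the old boolean `with_page_index(true)` toggle:

```rust
use parquet::arrow::ParquetRecordBatchStreamBuilder;
use parquet::arrow::arrow_reader::ArrowReaderOptions;
use parquet::file::metadata::PageIndexPolicy;
use tokio::fs::File;

async fn read_with_required_page_index() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical path; the test above uses parquet_test_data().
    let file = File::open("alltypes_tiny_pages.parquet").await?;
    // Require the column/offset indexes to be loaded along with the footer.
    let options =
        ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required);
    let builder =
        ParquetRecordBatchStreamBuilder::new_with_options(file, options).await?;
    // The indexes are now available through the parquet metadata.
    assert!(builder.metadata().column_index().is_some());
    Ok(())
}
```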
@@ -1204,10 +1204,10 @@ mod tests { let result = df.collect().await?; - assert_snapshot!(batches_to_string(&result), @r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result), @r" + ++ + ++ + "); Ok(()) } @@ -1233,10 +1233,10 @@ mod tests { let result = df.collect().await?; - assert_snapshot!(batches_to_string(&result), @r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result), @r" + ++ + ++ + "); Ok(()) } @@ -1364,6 +1364,28 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_write_empty_parquet_from_sql() -> Result<()> { + let ctx = SessionContext::new(); + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_sql.parquet", tmp_dir.path().to_string_lossy()); + let df = ctx.sql("SELECT CAST(1 AS INT) AS id LIMIT 0").await?; + df.write_parquet(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + // Expected the file to exist + assert!(std::path::Path::new(&path).exists()); + let read_df = ctx.read_parquet(&path, ParquetReadOptions::new()).await?; + let stream = read_df.execute_stream().await?; + assert_eq!(stream.schema().fields().len(), 1); + assert_eq!(stream.schema().field(0).name(), "id"); + + let results: Vec<_> = stream.collect().await; + assert_eq!(results.len(), 0); + + Ok(()) + } + #[tokio::test] async fn parquet_sink_write_insert_schema_into_metadata() -> Result<()> { // expected kv metadata without schema @@ -1523,6 +1545,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1614,6 +1637,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1704,6 +1728,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index c206566a65941..85dee3f91cffb 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -21,7 +21,8 @@ mod table; pub use datafusion_catalog_listing::helpers; pub use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; -pub use datafusion_datasource::{ - FileRange, ListingTableUrl, PartitionedFile, PartitionedFileStream, -}; +// Keep for backwards compatibility until removed +#[expect(deprecated)] +pub use datafusion_datasource::PartitionedFileStream; +pub use datafusion_datasource::{FileRange, ListingTableUrl, PartitionedFile}; pub use table::ListingTableConfigExt; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 3333b70676203..d14ec1f56dce2 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -107,14 +107,16 @@ impl ListingTableConfigExt for ListingTableConfig { #[cfg(test)] mod tests { + #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::listing::table::ListingTableConfigExt; + use crate::execution::options::JsonReadOptions; use crate::prelude::*; use crate::{ datasource::{ - 
file_format::csv::CsvFormat, file_format::json::JsonFormat, - provider_as_source, DefaultTableSource, MemTable, + DefaultTableSource, MemTable, file_format::csv::CsvFormat, + file_format::json::JsonFormat, provider_as_source, }, execution::options::ArrowReadOptions, test::{ @@ -129,33 +131,26 @@ mod tests { ListingOptions, ListingTable, ListingTableConfig, SchemaSource, }; use datafusion_common::{ - assert_contains, plan_err, + DataFusionError, Result, ScalarValue, assert_contains, stats::Precision, test_util::{batches_to_string, datafusion_test_data}, - ColumnStatistics, DataFusionError, Result, ScalarValue, }; + use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, - }; - use datafusion_datasource::ListingTableUrl; use datafusion_expr::dml::InsertOp; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; - use datafusion_physical_expr::expressions::binary; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::expressions::binary; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::empty::EmptyExec; - use datafusion_physical_plan::{collect, ExecutionPlanProperties}; - use rstest::rstest; + use datafusion_physical_plan::{ExecutionPlanProperties, collect}; use std::collections::HashMap; use std::io::Write; use std::sync::Arc; use tempfile::TempDir; use url::Url; - const DUMMY_NULL_COUNT: Precision = Precision::Exact(42); - /// Creates a test schema with standard field types used in tests fn create_test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -257,7 +252,7 @@ mod tests { ); assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -289,32 +284,36 @@ mod tests { // sort expr, but non column ( vec![vec![col("int_col").add(lit(1)).sort(true, true)]], - Ok(vec![[PhysicalSortExpr { - expr: binary( - physical_col("int_col", &schema).unwrap(), - Operator::Plus, - physical_lit(1), - &schema, - ) - .unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }] - .into()]), + Ok(vec![ + [PhysicalSortExpr { + expr: binary( + physical_col("int_col", &schema).unwrap(), + Operator::Plus, + physical_lit(1), + &schema, + ) + .unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }] + .into(), + ]), ), // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![[PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }] - .into()]), + Ok(vec![ + [PhysicalSortExpr { + expr: physical_col("string_col", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }] + .into(), + ]), ), // ok with two columns, different options ( @@ -322,19 +321,21 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![[ - PhysicalSortExpr::new_default( - physical_col("string_col", &schema).unwrap(), - ) - .asc() - .nulls_last(), - PhysicalSortExpr::new_default( - physical_col("int_col", &schema).unwrap(), - ) - .desc() - .nulls_first(), - ] - .into()]), + Ok(vec![ + [ + PhysicalSortExpr::new_default( + physical_col("string_col", &schema).unwrap(), + ) + .asc() + .nulls_last(), + PhysicalSortExpr::new_default( + 
physical_col("int_col", &schema).unwrap(), + ) + .desc() + .nulls_first(), + ] + .into(), + ]), ), ]; @@ -348,7 +349,7 @@ mod tests { let table = ListingTable::try_new(config.clone()).expect("Creating the table"); let ordering_result = - table.try_create_output_ordering(state.execution_props()); + table.try_create_output_ordering(state.execution_props(), &[]); match (expected_result, ordering_result) { (Ok(expected), Ok(result)) => { @@ -404,7 +405,7 @@ mod tests { .await .expect("Empty execution plan"); - assert!(scan.as_any().is::()); + assert!(scan.is::()); assert_eq!( columns(&scan.schema()), vec!["a".to_owned(), "p1".to_owned()] @@ -453,9 +454,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -488,9 +489,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -538,9 +539,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -731,8 +732,8 @@ mod tests { } #[tokio::test] - async fn test_insert_into_append_new_parquet_files_invalid_session_fails( - ) -> Result<()> { + async fn test_insert_into_append_new_parquet_files_invalid_session_fails() + -> Result<()> { let mut config_map: HashMap = HashMap::new(); config_map.insert( "datafusion.execution.parquet.compression".into(), @@ -746,7 +747,10 @@ mod tests { ) .await .expect_err("Example should fail!"); - assert_eq!(e.strip_backtrace(), "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)"); + assert_eq!( + e.strip_backtrace(), + "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)" + ); Ok(()) } @@ -806,7 +810,7 @@ mod tests { .register_json( "t", tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default() + JsonReadOptions::default() .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -873,13 +877,13 @@ mod tests { let res = collect(plan, session_ctx.task_ctx()).await?; // Insert returns the number of rows written, in our case this would be 6. - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Read the records in the table let batches = session_ctx @@ -888,13 +892,13 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! 
{insta::assert_snapshot!(batches_to_string(&batches),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Assert that `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); @@ -909,13 +913,13 @@ mod tests { // Again, execute the physical plan and collect the results let res = collect(plan, session_ctx.task_ctx()).await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Read the contents of the table let batches = session_ctx @@ -924,13 +928,13 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-------+ - | count | - +-------+ - | 40 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r" + +-------+ + | count | + +-------+ + | 40 | + +-------+ + ");} // Assert that another `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); @@ -988,15 +992,15 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-----+-----+---+ - | a | b | c | - +-----+-----+---+ - | foo | bar | 1 | - | foo | bar | 2 | - | foo | bar | 3 | - +-----+-----+---+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r" + +-----+-----+---+ + | a | b | c | + +-----+-----+---+ + | foo | bar | 1 | + | foo | bar | 2 | + | foo | bar | 3 | + +-----+-----+---+ + ");} Ok(()) } @@ -1307,10 +1311,10 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), 1); + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; + assert_eq!(result.file_groups.len(), 1); - let files = file_list[0].clone(); + let files = result.file_groups[0].clone(); assert_eq!( files @@ -1397,7 +1401,7 @@ mod tests { // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec_enabled.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -1416,7 +1420,9 @@ mod tests { ]; for (format, batch_size, soft_max_rows, expected_files) in test_cases { - println!("Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}"); + println!( + "Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}" + ); let mut config_map = HashMap::new(); config_map.insert( @@ -1449,33 +1455,10 @@ mod tests { } #[tokio::test] - async fn test_statistics_mapping_with_custom_factory() -> Result<()> { - let ctx = SessionContext::new(); - let table = create_test_listing_table_with_json_and_adapter( - &ctx, - false, - // NullStatsAdapterFactory sets column_statistics null_count to DUMMY_NULL_COUNT - Arc::new(NullStatsAdapterFactory {}), - )?; - - let (groups, stats) = table.list_files_for_scan(&ctx.state(), &[], None).await?; - - assert_eq!(stats.column_statistics[0].null_count, DUMMY_NULL_COUNT); - for g in groups { - if let Some(s) = g.file_statistics(None) { - assert_eq!(s.column_statistics[0].null_count, DUMMY_NULL_COUNT); - } - } - - Ok(()) - } - - #[tokio::test] - async fn 
test_statistics_mapping_with_default_factory() -> Result<()> { + async fn test_basic_table_scan() -> Result<()> { let ctx = SessionContext::new(); - // Create a table without providing a custom schema adapter factory - // This should fall back to using DefaultSchemaAdapterFactory + // Test basic table creation and scanning let path = "table/file.json"; register_test_store(&ctx, &[(path, 10)]); @@ -1487,222 +1470,20 @@ mod tests { let config = ListingTableConfig::new(table_path) .with_listing_options(opt) .with_schema(Arc::new(schema)); - // Note: NOT calling .with_schema_adapter_factory() to test default behavior let table = ListingTable::try_new(config)?; - // Verify that no custom schema adapter factory is set - assert!(table.schema_adapter_factory().is_none()); - - // The scan should work correctly with the default schema adapter + // The scan should work correctly let scan_result = table.scan(&ctx.state(), None, &[], None).await; - assert!( - scan_result.is_ok(), - "Scan should succeed with default schema adapter" - ); - - // Verify that the default adapter handles basic schema compatibility - let (groups, _stats) = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert!( - !groups.is_empty(), - "Should list files successfully with default adapter" - ); - - Ok(()) - } - - #[rstest] - #[case(MapSchemaError::TypeIncompatible, "Cannot map incompatible types")] - #[case(MapSchemaError::GeneralFailure, "Schema adapter mapping failed")] - #[case( - MapSchemaError::InvalidProjection, - "Invalid projection in schema mapping" - )] - #[tokio::test] - async fn test_schema_adapter_map_schema_errors( - #[case] error_type: MapSchemaError, - #[case] expected_error_msg: &str, - ) -> Result<()> { - let ctx = SessionContext::new(); - let table = create_test_listing_table_with_json_and_adapter( - &ctx, - false, - Arc::new(FailingMapSchemaAdapterFactory { error_type }), - )?; - - // The error should bubble up from the scan operation when schema mapping fails - let scan_result = table.scan(&ctx.state(), None, &[], None).await; - - assert!(scan_result.is_err()); - let error_msg = scan_result.unwrap_err().to_string(); - assert!( - error_msg.contains(expected_error_msg), - "Expected error containing '{expected_error_msg}', got: {error_msg}" - ); - - Ok(()) - } - - // Test that errors during file listing also bubble up correctly - #[tokio::test] - async fn test_schema_adapter_error_during_file_listing() -> Result<()> { - let ctx = SessionContext::new(); - let table = create_test_listing_table_with_json_and_adapter( - &ctx, - true, - Arc::new(FailingMapSchemaAdapterFactory { - error_type: MapSchemaError::TypeIncompatible, - }), - )?; + assert!(scan_result.is_ok(), "Scan should succeed"); - // The error should bubble up from list_files_for_scan when collecting statistics - let list_result = table.list_files_for_scan(&ctx.state(), &[], None).await; - - assert!(list_result.is_err()); - let error_msg = list_result.unwrap_err().to_string(); + // Verify file listing works + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; assert!( - error_msg.contains("Cannot map incompatible types"), - "Expected type incompatibility error during file listing, got: {error_msg}" + !result.file_groups.is_empty(), + "Should list files successfully" ); Ok(()) } - - #[derive(Debug, Copy, Clone)] - enum MapSchemaError { - TypeIncompatible, - GeneralFailure, - InvalidProjection, - } - - #[derive(Debug)] - struct FailingMapSchemaAdapterFactory { - error_type: MapSchemaError, - } - - impl 
SchemaAdapterFactory for FailingMapSchemaAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box<dyn SchemaAdapter> { - Box::new(FailingMapSchemaAdapter { - schema: projected_table_schema, - error_type: self.error_type, - }) - } - } - - #[derive(Debug)] - struct FailingMapSchemaAdapter { - schema: SchemaRef, - error_type: MapSchemaError, - } - - impl SchemaAdapter for FailingMapSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> { - let field = self.schema.field(index); - file_schema.fields.find(field.name()).map(|(i, _)| i) - } - - fn map_schema( - &self, - _file_schema: &Schema, - ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> { - // Always fail with different error types based on the configured error_type - match self.error_type { - MapSchemaError::TypeIncompatible => { - plan_err!( - "Cannot map incompatible types: Boolean cannot be cast to Utf8" - ) - } - MapSchemaError::GeneralFailure => { - plan_err!("Schema adapter mapping failed due to internal error") - } - MapSchemaError::InvalidProjection => { - plan_err!("Invalid projection in schema mapping: column index out of bounds") - } - } - } - } - - #[derive(Debug)] - struct NullStatsAdapterFactory; - - impl SchemaAdapterFactory for NullStatsAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box<dyn SchemaAdapter> { - Box::new(NullStatsAdapter { - schema: projected_table_schema, - }) - } - } - - #[derive(Debug)] - struct NullStatsAdapter { - schema: SchemaRef, - } - - impl SchemaAdapter for NullStatsAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> { - let field = self.schema.field(index); - file_schema.fields.find(field.name()).map(|(i, _)| i) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> { - let projection = (0..file_schema.fields().len()).collect(); - Ok((Arc::new(NullStatsMapper {}), projection)) - } - } - - #[derive(Debug)] - struct NullStatsMapper; - - impl SchemaMapper for NullStatsMapper { - fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> { - Ok(batch) - } - - fn map_column_statistics( - &self, - stats: &[ColumnStatistics], - ) -> Result<Vec<ColumnStatistics>> { - Ok(stats - .iter() - .map(|s| { - let mut s = s.clone(); - s.null_count = DUMMY_NULL_COUNT; - s - }) - .collect()) - } - } - - /// Helper function to create a test ListingTable with JSON format and custom schema adapter factory - fn create_test_listing_table_with_json_and_adapter( - ctx: &SessionContext, - collect_stat: bool, - schema_adapter_factory: Arc<dyn SchemaAdapterFactory>, - ) -> Result<ListingTable> { - let path = "table/file.json"; - register_test_store(ctx, &[(path, 10)]); - - let format = JsonFormat::default(); - let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(collect_stat); - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - let table_path = ListingTableUrl::parse("test:///table/")?; - - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(Arc::new(schema)) - .with_schema_adapter_factory(schema_adapter_factory); - - ListingTable::try_new(config) - } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index f98297d0e3f7f..a5139346752a9 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -28,8 +28,8 @@ use crate::datasource::listing::{ use crate::execution::context::SessionState; use arrow::datatypes::DataType; 
-use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema}; -use datafusion_common::{config_datafusion_err, Result}; +use datafusion_common::{Result, config_datafusion_err}; +use datafusion_common::{ToDFSchema, arrow_datafusion_err, plan_err}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -54,7 +54,15 @@ impl TableProviderFactory for ListingTableFactory { cmd: &CreateExternalTable, ) -> Result<Arc<dyn TableProvider>> { // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state? - let session_state = state.as_any().downcast_ref::<SessionState>().unwrap(); + let session_state = + state + .as_any() + .downcast_ref::<SessionState>() + .ok_or_else(|| { + datafusion_common::internal_datafusion_err!( + "ListingTableFactory requires SessionState" + ) + })?; let file_format = session_state .get_file_format_factory(cmd.file_type.as_str()) .ok_or(config_datafusion_err!( ... ))? .create(session_state, &cmd.options)?; - let mut table_path = ListingTableUrl::parse(&cmd.location)?; + let mut table_path = + ListingTableUrl::parse(&cmd.location)?.with_table_ref(cmd.name.clone()); let file_extension = match table_path.is_collection() { // Setting the extension to be empty instead of allowing the default extension seems // odd, but was done to ensure existing behavior isn't modified. It seems like this @@ -190,6 +199,16 @@ .with_definition(cmd.definition.clone()) .with_constraints(cmd.constraints.clone()) .with_column_defaults(cmd.column_defaults.clone()); + + // Pre-warm statistics cache if collect_statistics is enabled + if session_state.config().collect_statistics() { + let filters = &[]; + let limit = None; + if let Err(e) = table.list_files_for_scan(state, filters, limit).await { + log::warn!("Failed to pre-warm statistics cache: {e}"); + } + } + Ok(Arc::new(table)) } } @@ -205,19 +224,24 @@ fn get_extension(path: &str) -> String { #[cfg(test)] mod tests { + use super::*; + use crate::{ + datasource::file_format::csv::CsvFormat, execution::context::SessionContext, + test_util::parquet_test_data, + }; + use datafusion_execution::cache::CacheAccessor; + use datafusion_execution::cache::cache_manager::CacheManagerConfig; + use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use glob::Pattern; use std::collections::HashMap; use std::fs; use std::path::PathBuf; - use super::*; - use crate::{ - datasource::file_format::csv::CsvFormat, execution::context::SessionContext, - }; - use datafusion_common::parsers::CompressionTypeVariant; - use datafusion_common::{Constraints, DFSchema, TableReference}; + use datafusion_common::{DFSchema, TableReference}; + use datafusion_expr::registry::ExtensionTypeRegistryRef; #[tokio::test] async fn test_create_using_non_std_file_ext() { ... let context = SessionContext::new(); let state = context.state(); let name = TableReference::bare("foo"); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: csv_file.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, 
- options: HashMap::from([("format.has_header".into(), "true".into())]), - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + csv_file.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(HashMap::from([("format.has_header".into(), "true".into())])) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); - let listing_table = table_provider - .as_any() - .downcast_ref::<ListingTable>() - .unwrap(); + let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap(); let listing_options = listing_table.options(); assert_eq!(".tbl", listing_options.file_extension); } @@ -272,30 +285,19 @@ mod tests { let mut options = HashMap::new(); options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: csv_file.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options, - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + csv_file.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); - let listing_table = table_provider - .as_any() - .downcast_ref::<ListingTable>() - .unwrap(); + let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap(); let format = listing_table.options().format.clone(); - let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap(); + let csv_format = format.downcast_ref::<CsvFormat>().unwrap(); let csv_options = csv_format.options().clone(); assert_eq!(csv_options.schema_infer_max_rec, Some(1000)); let listing_options = listing_table.options(); @@ -317,31 +319,20 @@ mod tests { options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); options.insert("format.compression".into(), "gzip".into()); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: dir.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options, - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + dir.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); - let listing_table = table_provider - .as_any() - .downcast_ref::<ListingTable>() - .unwrap(); + let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap(); // Verify compression is used let format = listing_table.options().format.clone(); - let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap(); + let csv_format = format.downcast_ref::<CsvFormat>().unwrap(); let csv_options = csv_format.options().clone(); assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP); @@ -369,27 +360,16 @@ mod tests { let mut options = HashMap::new(); options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); - let cmd = CreateExternalTable { + 
@@ -369,27 +360,16 @@ mod tests {
         let mut options = HashMap::new();
         options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
         options.insert("format.has_header".into(), "true".into());
-        let cmd = CreateExternalTable {
+        let cmd = CreateExternalTable::builder(
             name,
-            location: dir.path().to_str().unwrap().to_string(),
-            file_type: "csv".to_string(),
-            schema: Arc::new(DFSchema::empty()),
-            table_partition_cols: vec![],
-            if_not_exists: false,
-            or_replace: false,
-            temporary: false,
-            definition: None,
-            order_exprs: vec![],
-            unbounded: false,
-            options,
-            constraints: Constraints::default(),
-            column_defaults: HashMap::new(),
-        };
+            dir.path().to_str().unwrap().to_string(),
+            "csv",
+            Arc::new(DFSchema::empty()),
+        )
+        .with_options(options)
+        .build();
         let table_provider = factory.create(&state, &cmd).await.unwrap();
-        let listing_table = table_provider
-            .as_any()
-            .downcast_ref::<ListingTable>()
-            .unwrap();
+        let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap();
         let listing_options = listing_table.options();

         assert_eq!("", listing_options.file_extension);
@@ -413,27 +393,15 @@ mod tests {
         let state = context.state();
         let name = TableReference::bare("foo");

-        let cmd = CreateExternalTable {
+        let cmd = CreateExternalTable::builder(
             name,
-            location: String::from(path.to_str().unwrap()),
-            file_type: "parquet".to_string(),
-            schema: Arc::new(DFSchema::empty()),
-            table_partition_cols: vec![],
-            if_not_exists: false,
-            or_replace: false,
-            temporary: false,
-            definition: None,
-            order_exprs: vec![],
-            unbounded: false,
-            options: HashMap::new(),
-            constraints: Constraints::default(),
-            column_defaults: HashMap::new(),
-        };
+            String::from(path.to_str().unwrap()),
+            "parquet",
+            Arc::new(DFSchema::empty()),
+        )
+        .build();
         let table_provider = factory.create(&state, &cmd).await.unwrap();
-        let listing_table = table_provider
-            .as_any()
-            .downcast_ref::<ListingTable>()
-            .unwrap();
+        let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap();
         let listing_options = listing_table.options();

         assert_eq!("", listing_options.file_extension);
@@ -453,27 +421,15 @@ mod tests {
         let state = context.state();
         let name = TableReference::bare("foo");

-        let cmd = CreateExternalTable {
+        let cmd = CreateExternalTable::builder(
             name,
-            location: dir.path().to_str().unwrap().to_string(),
-            file_type: "parquet".to_string(),
-            schema: Arc::new(DFSchema::empty()),
-            table_partition_cols: vec![],
-            if_not_exists: false,
-            or_replace: false,
-            temporary: false,
-            definition: None,
-            order_exprs: vec![],
-            unbounded: false,
-            options: HashMap::new(),
-            constraints: Constraints::default(),
-            column_defaults: HashMap::new(),
-        };
+            dir.path().to_str().unwrap(),
+            "parquet",
+            Arc::new(DFSchema::empty()),
+        )
+        .build();
         let table_provider = factory.create(&state, &cmd).await.unwrap();
-        let listing_table = table_provider
-            .as_any()
-            .downcast_ref::<ListingTable>()
-            .unwrap();
+        let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap();
         let listing_options = listing_table.options();

         let dtype =
@@ -494,29 +450,193 @@ mod tests {
         let state = context.state();
         let name = TableReference::bare("foo");

-        let cmd = CreateExternalTable {
+        let cmd = CreateExternalTable::builder(
             name,
-            location: dir.path().to_str().unwrap().to_string(),
-            file_type: "parquet".to_string(),
-            schema: Arc::new(DFSchema::empty()),
-            table_partition_cols: vec![],
-            if_not_exists: false,
-            or_replace: false,
-            temporary: false,
-            definition: None,
-            order_exprs: vec![],
-            unbounded: false,
-            options: HashMap::new(),
-            constraints: Constraints::default(),
-            column_defaults: HashMap::new(),
-        };
+            dir.path().to_str().unwrap().to_string(),
+            "parquet",
+            Arc::new(DFSchema::empty()),
+        )
+        .build();
         let table_provider = factory.create(&state, &cmd).await.unwrap();
-        let listing_table = table_provider
-            .as_any()
-            .downcast_ref::<ListingTable>()
-            .unwrap();
+        let listing_table = table_provider.downcast_ref::<ListingTable>().unwrap();
         let listing_options = listing_table.options();

         assert!(listing_options.table_partition_cols.is_empty());
     }
+
+    #[tokio::test]
+    async fn test_statistics_cache_prewarming() {
+        let factory = ListingTableFactory::new();
+
+        let location = PathBuf::from(parquet_test_data())
+            .join("alltypes_tiny_pages_plain.parquet")
+            .to_string_lossy()
+            .to_string();
+
+        // Test with collect_statistics enabled
+        let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default());
+        let cache_config = CacheManagerConfig::default()
+            .with_files_statistics_cache(Some(file_statistics_cache.clone()));
+        let runtime = RuntimeEnvBuilder::new()
+            .with_cache_manager(cache_config)
+            .build_arc()
+            .unwrap();
+
+        let mut config = SessionConfig::new();
+        config.options_mut().execution.collect_statistics = true;
+        let context = SessionContext::new_with_config_rt(config, runtime);
+        let state = context.state();
+        let name = TableReference::bare("test");
+
+        let cmd = CreateExternalTable::builder(
+            name,
+            location.clone(),
+            "parquet",
+            Arc::new(DFSchema::empty()),
+        )
+        .build();
+
+        let _table_provider = factory.create(&state, &cmd).await.unwrap();
+
+        assert!(
+            file_statistics_cache.len() > 0,
+            "Statistics cache should be pre-warmed when collect_statistics is enabled"
+        );
+
+        // Test with collect_statistics disabled
+        let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default());
+        let cache_config = CacheManagerConfig::default()
+            .with_files_statistics_cache(Some(file_statistics_cache.clone()));
+        let runtime = RuntimeEnvBuilder::new()
+            .with_cache_manager(cache_config)
+            .build_arc()
+            .unwrap();
+
+        let mut config = SessionConfig::new();
+        config.options_mut().execution.collect_statistics = false;
+        let context = SessionContext::new_with_config_rt(config, runtime);
+        let state = context.state();
+        let name = TableReference::bare("test");
+
+        let cmd = CreateExternalTable::builder(
+            name,
+            location,
+            "parquet",
+            Arc::new(DFSchema::empty()),
+        )
+        .build();
+
+        let _table_provider = factory.create(&state, &cmd).await.unwrap();
+
+        assert_eq!(
+            file_statistics_cache.len(),
+            0,
+            "Statistics cache should not be pre-warmed when collect_statistics is disabled"
+        );
+    }
+
+    #[tokio::test]
+    async fn test_create_with_invalid_session() {
+        use datafusion_common::config::TableOptions;
+        use datafusion_execution::TaskContext;
+        use datafusion_execution::config::SessionConfig;
+        use datafusion_physical_expr::PhysicalExpr;
+        use datafusion_physical_plan::ExecutionPlan;
+        use std::any::Any;
+        use std::collections::HashMap;
+
+        // A mock Session that is NOT SessionState
+        #[derive(Debug)]
+        struct MockSession;
+
+        #[async_trait]
+        impl Session for MockSession {
+            fn session_id(&self) -> &str {
+                "mock_session"
+            }
+            fn config(&self) -> &SessionConfig {
+                unimplemented!()
+            }
+            async fn create_physical_plan(
+                &self,
+                _logical_plan: &datafusion_expr::LogicalPlan,
+            ) -> Result<Arc<dyn ExecutionPlan>> {
+                unimplemented!()
+            }
+            fn create_physical_expr(
+                &self,
+                _expr: datafusion_expr::Expr,
+                _df_schema: &DFSchema,
+            ) -> Result<Arc<dyn PhysicalExpr>> {
+                unimplemented!()
+            }
+            fn scalar_functions(
+                &self,
+            ) -> &HashMap<String, Arc<datafusion_expr::ScalarUDF>> {
+                unimplemented!()
+            }
+            fn higher_order_functions(
+                &self,
+            ) -> &HashMap> {
+                unimplemented!()
+            }
+            fn aggregate_functions(
+                &self,
+            ) -> &HashMap<String, Arc<datafusion_expr::AggregateUDF>> {
+                unimplemented!()
+            }
+            fn window_functions(
+                &self,
+            ) -> &HashMap<String, Arc<datafusion_expr::WindowUDF>> {
+                unimplemented!()
+            }

+            fn extension_type_registry(&self) -> &ExtensionTypeRegistryRef {
+                unreachable!()
+            }

+            fn runtime_env(&self) -> &Arc<RuntimeEnv> {
+                unimplemented!()
+            }
+            fn execution_props(
+                &self,
+            ) -> &datafusion_expr::execution_props::ExecutionProps {
+                unimplemented!()
+            }
+            fn as_any(&self) -> &dyn Any {
+                self
+            }
+            fn table_options(&self) -> &TableOptions {
+                unimplemented!()
+            }
+            fn table_options_mut(&mut self) -> &mut TableOptions {
+                unimplemented!()
+            }
+            fn task_ctx(&self) -> Arc<TaskContext> {
+                unimplemented!()
+            }
+        }
+
+        let factory = ListingTableFactory::new();
+        let mock_session = MockSession;
+
+        let name = TableReference::bare("foo");
+        let cmd = CreateExternalTable::builder(
+            name,
+            "foo.csv".to_string(),
+            "csv",
+            Arc::new(DFSchema::empty()),
+        )
+        .build();
+
+        // This should return an error, not panic
+        let result = factory.create(&mock_session, &cmd).await;
+        assert!(result.is_err());
+        assert!(
+            result
+                .unwrap_err()
+                .strip_backtrace()
+                .contains("Internal error: ListingTableFactory requires SessionState")
+        );
+    }
 }
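The `test_create_with_invalid_session` test added above pins down that a `Session` implementation other than `SessionState` is rejected with an internal error rather than a panic. A self-contained sketch of the kind of downcast guard that yields this behavior; the trait and types here are stand-ins for illustration, not DataFusion's actual definitions:

use std::any::Any;

trait Session {
    fn as_any(&self) -> &dyn Any;
}

struct SessionState;

impl Session for SessionState {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

struct MockSession;

impl Session for MockSession {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn require_session_state(session: &dyn Session) -> Result<&SessionState, String> {
    // Downcast instead of assuming: a foreign `Session` impl (like MockSession)
    // produces an error the caller can surface, never a panic.
    session
        .as_any()
        .downcast_ref::<SessionState>()
        .ok_or_else(|| "Internal error: ListingTableFactory requires SessionState".to_string())
}

fn main() {
    assert!(require_session_state(&SessionState).is_ok());
    assert!(require_session_state(&MockSession).is_err());
}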
diff --git a/datafusion/core/src/datasource/memory_test.rs b/datafusion/core/src/datasource/memory_test.rs
index c16837c73b4f1..c7721cafb02ea 100644
--- a/datafusion/core/src/datasource/memory_test.rs
+++ b/datafusion/core/src/datasource/memory_test.rs
@@ -19,7 +19,7 @@
 mod tests {
     use crate::datasource::MemTable;
-    use crate::datasource::{provider_as_source, DefaultTableSource};
+    use crate::datasource::{DefaultTableSource, provider_as_source};
     use crate::physical_plan::collect;
     use crate::prelude::SessionContext;
     use arrow::array::{AsArray, Int32Array};
@@ -29,8 +29,8 @@ mod tests {
     use arrow_schema::SchemaRef;
     use datafusion_catalog::TableProvider;
     use datafusion_common::{DataFusionError, Result};
-    use datafusion_expr::dml::InsertOp;
     use datafusion_expr::LogicalPlanBuilder;
+    use datafusion_expr::dml::InsertOp;
     use futures::StreamExt;
     use std::collections::HashMap;
     use std::sync::Arc;
@@ -329,12 +329,11 @@ mod tests {
         );
         let col = batch.column(0).as_primitive::<Int32Type>();
         assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
-        let val = col
-            .iter()
+
+        col.iter()
             .next()
             .expect("had value")
-            .expect("expected non null");
-        val
+            .expect("expected non null")
     }

     // Test inserting a single batch of data into a single partition
diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs
index 37b9663111a53..de54078aafef4 100644
--- a/datafusion/core/src/datasource/mod.rs
+++ b/datafusion/core/src/datasource/mod.rs
@@ -20,7 +20,6 @@
 //! 
[`ListingTable`]: crate::datasource::listing::ListingTable pub mod dynamic_file; -pub mod empty; pub mod file_format; pub mod listing; pub mod listing_table_factory; @@ -31,7 +30,7 @@ mod view_test; // backwards compatibility pub use self::default_table_source::{ - provider_as_source, source_as_provider, DefaultTableSource, + DefaultTableSource, provider_as_source, source_as_provider, }; pub use self::memory::MemTable; pub use self::view::ViewTable; @@ -39,9 +38,11 @@ pub use crate::catalog::TableProvider; pub use crate::logical_expr::TableType; pub use datafusion_catalog::cte_worktable; pub use datafusion_catalog::default_table_source; +pub use datafusion_catalog::empty; pub use datafusion_catalog::memory; pub use datafusion_catalog::stream; pub use datafusion_catalog::view; +pub use datafusion_datasource::projection; pub use datafusion_datasource::schema_adapter; pub use datafusion_datasource::sink; pub use datafusion_datasource::source; @@ -53,32 +54,35 @@ pub use datafusion_physical_expr::create_ordering; mod tests { use crate::prelude::SessionContext; - use ::object_store::{path::Path, ObjectMeta}; + use ::object_store::{ObjectMeta, path::Path}; use arrow::{ - array::{Int32Array, StringArray}, + array::Int32Array, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; - use datafusion_common::{record_batch, test_util::batches_to_sort_string}; + use datafusion_common::{ + Result, ScalarValue, + test_util::batches_to_sort_string, + tree_node::{Transformed, TransformedResult, TreeNode}, + }; use datafusion_datasource::{ - file::FileSource, - file_scan_config::FileScanConfigBuilder, - schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, - SchemaMapper, - }, - source::DataSourceExec, - PartitionedFile, + PartitionedFile, file_scan_config::FileScanConfigBuilder, source::DataSourceExec, }; use datafusion_datasource_parquet::source::ParquetSource; + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_adapter::{ + PhysicalExprAdapter, PhysicalExprAdapterFactory, + }; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::collect; use std::{fs, sync::Arc}; use tempfile::TempDir; + use url::Url; #[tokio::test] - async fn can_override_schema_adapter() { - // Test shows that SchemaAdapter can add a column that doesn't existing in the - // record batches returned from parquet. This can be useful for schema evolution + async fn can_override_physical_expr_adapter() { + // Test shows that PhysicalExprAdapter can add a column that doesn't exist in the + // record batches returned from parquet. This can be useful for schema evolution // where older files may not have all columns. 
use datafusion_execution::object_store::ObjectStoreUrl; @@ -101,7 +105,8 @@ mod tests { writer.write(&rec_batch).unwrap(); writer.close().unwrap(); - let location = Path::parse(path.to_str().unwrap()).unwrap(); + let url = Url::from_file_path(path.canonicalize().unwrap()).unwrap(); + let location = Path::from_url_path(url.path()).unwrap(); let metadata = fs::metadata(path.as_path()).expect("Local file metadata"); let meta = ObjectMeta { location, @@ -111,29 +116,18 @@ mod tests { version: None, }; - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(meta); let f1 = Field::new("id", DataType::Int32, true); let f2 = Field::new("extra_column", DataType::Utf8, true); let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); - let source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})) - .unwrap(); - let base_conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - schema, - source, - ) - .with_file(partitioned_file) - .build(); + let source = Arc::new(ParquetSource::new(Arc::clone(&schema))); + let base_conf = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(partitioned_file) + .with_expr_adapter(Some(Arc::new(TestPhysicalExprAdapterFactory))) + .build(); let parquet_exec = DataSourceExec::from_data_source(base_conf); @@ -141,134 +135,52 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let read = collect(parquet_exec, task_ctx).await.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" + insta::assert_snapshot!(batches_to_sort_string(&read),@r" +----+--------------+ | id | extra_column | +----+--------------+ | 1 | foo | +----+--------------+ - "###); - } - - #[test] - fn default_schema_adapter() { - let table_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ]); - - // file has a subset of the table schema fields and different type - let file_schema = Schema::new(vec![ - Field::new("c", DataType::Float64, true), // not in table schema - Field::new("b", DataType::Float64, true), - ]); - - let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); - let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(indices, vec![1]); - - let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); - - let mapped_batch = mapper.map_batch(file_batch).unwrap(); - - // the mapped batch has the correct schema and the "b" column has been cast to Utf8 - let expected_batch = record_batch!( - ("a", Int32, vec![None, None]), // missing column filled with nulls - ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed - ) - .unwrap(); - assert_eq!(mapped_batch, expected_batch); - } - - #[test] - fn default_schema_adapter_non_nullable_columns() { - let table_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), // "a"" is declared non nullable - Field::new("b", DataType::Utf8, true), - ]); - let file_schema = Schema::new(vec![ - // since file doesn't have "a" it will be filled with nulls - Field::new("b", DataType::Float64, true), - ]); - - let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); - let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(indices, vec![0]); - - let file_batch = 
record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap();
-
-        // Mapping fails because it tries to fill in a non-nullable column with nulls
-        let err = mapper.map_batch(file_batch).unwrap_err().to_string();
-        assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}");
+        ");
     }

     #[derive(Debug)]
-    struct TestSchemaAdapterFactory;
+    struct TestPhysicalExprAdapterFactory;

-    impl SchemaAdapterFactory for TestSchemaAdapterFactory {
+    impl PhysicalExprAdapterFactory for TestPhysicalExprAdapterFactory {
         fn create(
             &self,
-            projected_table_schema: SchemaRef,
-            _table_schema: SchemaRef,
-        ) -> Box<dyn SchemaAdapter> {
-            Box::new(TestSchemaAdapter {
-                table_schema: projected_table_schema,
-            })
+            _logical_file_schema: SchemaRef,
+            physical_file_schema: SchemaRef,
+        ) -> Result<Arc<dyn PhysicalExprAdapter>> {
+            Ok(Arc::new(TestPhysicalExprAdapter {
+                physical_file_schema,
+            }))
         }
     }

-    struct TestSchemaAdapter {
-        /// Schema for the table
-        table_schema: SchemaRef,
+    #[derive(Debug)]
+    struct TestPhysicalExprAdapter {
+        physical_file_schema: SchemaRef,
     }

-    impl SchemaAdapter for TestSchemaAdapter {
-        fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
-            let field = self.table_schema.field(index);
-            Some(file_schema.fields.find(field.name())?.0)
-        }
-
-        fn map_schema(
-            &self,
-            file_schema: &Schema,
-        ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
-            let mut projection = Vec::with_capacity(file_schema.fields().len());
-
-            for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
-                if self.table_schema.fields().find(file_field.name()).is_some() {
-                    projection.push(file_idx);
+    impl PhysicalExprAdapter for TestPhysicalExprAdapter {
+        fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+            expr.transform(|e| {
+                if let Some(column) = e.downcast_ref::<Column>() {
+                    // If column is "extra_column" and missing from physical schema, inject "foo"
+                    if column.name() == "extra_column"
+                        && self.physical_file_schema.index_of("extra_column").is_err()
+                    {
+                        return Ok(Transformed::yes(Arc::new(Literal::new(
+                            ScalarValue::Utf8(Some("foo".to_string())),
+                        ))
+                            as Arc<dyn PhysicalExpr>));
+                    }
                 }
-            }
-
-            Ok((Arc::new(TestSchemaMapping {}), projection))
-        }
-    }
-
-    #[derive(Debug)]
-    struct TestSchemaMapping {}
-
-    impl SchemaMapper for TestSchemaMapping {
-        fn map_batch(
-            &self,
-            batch: RecordBatch,
-        ) -> datafusion_common::Result<RecordBatch> {
-            let f1 = Field::new("id", DataType::Int32, true);
-            let f2 = Field::new("extra_column", DataType::Utf8, true);
-
-            let schema = Arc::new(Schema::new(vec![f1, f2]));
-
-            let extra_column = Arc::new(StringArray::from(vec!["foo"]));
-            let mut new_columns = batch.columns().to_vec();
-            new_columns.push(extra_column);
-
-            Ok(RecordBatch::try_new(schema, new_columns).unwrap())
-        }
-
-        fn map_column_statistics(
-            &self,
-            _file_col_statistics: &[datafusion_common::ColumnStatistics],
-        ) -> datafusion_common::Result<Vec<datafusion_common::ColumnStatistics>> {
-            unimplemented!()
+                Ok(Transformed::no(e))
+            })
+            .data()
         }
     }
 }
diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs
index 9068c9758179d..2954a47403299 100644
--- a/datafusion/core/src/datasource/physical_plan/avro.rs
+++ b/datafusion/core/src/datasource/physical_plan/avro.rs
@@ -31,21 +31,21 @@ mod tests {
     use crate::test::object_store::local_unpartitioned_file;
     use arrow::datatypes::{DataType, Field, SchemaBuilder};
     use datafusion_common::test_util::batches_to_string;
-    use datafusion_common::{test_util, Result, ScalarValue};
+    use datafusion_common::{Result, ScalarValue, test_util};
     use 
datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::PartitionedFile; - use datafusion_datasource_avro::source::AvroSource; + use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_datasource_avro::AvroFormat; + use datafusion_datasource_avro::source::AvroSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_plan::ExecutionPlan; use datafusion_datasource::source::DataSourceExec; use futures::StreamExt; use insta::assert_snapshot; + use object_store::ObjectStore; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; - use object_store::ObjectStore; use rstest::*; use url::Url; @@ -81,15 +81,11 @@ mod tests { .infer_schema(&state, &store, std::slice::from_ref(&meta)) .await?; - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file(meta.into()) - .with_projection_indices(Some(vec![0, 1, 2])) - .build(); + let source = Arc::new(AvroSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(meta.into()) + .with_projection_indices(Some(vec![0, 1, 2]))? + .build(); let source_exec = DataSourceExec::from_data_source(conf); assert_eq!( @@ -109,20 +105,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+-------------+ - | id | bool_col | tinyint_col | - +----+----------+-------------+ - | 4 | true | 0 | - | 5 | false | 1 | - | 6 | true | 0 | - | 7 | false | 1 | - | 2 | true | 0 | - | 3 | false | 1 | - | 0 | true | 0 | - | 1 | false | 1 | - +----+----------+-------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+-------------+ + | id | bool_col | tinyint_col | + +----+----------+-------------+ + | 4 | true | 0 | + | 5 | false | 1 | + | 6 | true | 0 | + | 7 | false | 1 | + | 2 | true | 0 | + | 3 | false | 1 | + | 0 | true | 0 | + | 1 | false | 1 | + +----+----------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); @@ -157,10 +153,10 @@ mod tests { // Include the missing column in the projection let projection = Some(vec![0, 1, 2, actual_schema.fields().len()]); - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(AvroSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file(meta.into()) - .with_projection_indices(projection) + .with_projection_indices(projection)? .build(); let source_exec = DataSourceExec::from_data_source(conf); @@ -182,20 +178,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+-------------+-------------+ - | id | bool_col | tinyint_col | missing_col | - +----+----------+-------------+-------------+ - | 4 | true | 0 | | - | 5 | false | 1 | | - | 6 | true | 0 | | - | 7 | false | 1 | | - | 2 | true | 0 | | - | 3 | false | 1 | | - | 0 | true | 0 | | - | 1 | false | 1 | | - +----+----------+-------------+-------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+-------------+-------------+ + | id | bool_col | tinyint_col | missing_col | + +----+----------+-------------+-------------+ + | 4 | true | 0 | | + | 5 | false | 1 | | + | 6 | true | 0 | | + | 7 | false | 1 | | + | 2 | true | 0 | | + | 3 | false | 1 | | + | 0 | true | 0 | | + | 1 | false | 1 | | + +----+----------+-------------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); @@ -227,13 +223,16 @@ mod tests { partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let projection = Some(vec![0, 1, file_schema.fields().len(), 2]); - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let table_schema = TableSchema::new( + file_schema.clone(), + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + ); + let source = Arc::new(AvroSource::new(table_schema.clone())); + let conf = FileScanConfigBuilder::new(object_store_url, source) // select specific columns of the files as well as the partitioning // column which is supposed to be the last column in the table schema. - .with_projection_indices(projection) + .with_projection_indices(projection)? .with_file(partitioned_file) - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) .build(); let source_exec = DataSourceExec::from_data_source(conf); @@ -256,20 +255,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+------------+-------------+ - | id | bool_col | date | tinyint_col | - +----+----------+------------+-------------+ - | 4 | true | 2021-10-26 | 0 | - | 5 | false | 2021-10-26 | 1 | - | 6 | true | 2021-10-26 | 0 | - | 7 | false | 2021-10-26 | 1 | - | 2 | true | 2021-10-26 | 0 | - | 3 | false | 2021-10-26 | 1 | - | 0 | true | 2021-10-26 | 0 | - | 1 | false | 2021-10-26 | 1 | - +----+----------+------------+-------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r"
+    +----+----------+------------+-------------+
+    | id | bool_col | date       | tinyint_col |
+    +----+----------+------------+-------------+
+    | 4  | true     | 2021-10-26 | 0           |
+    | 5  | false    | 2021-10-26 | 1           |
+    | 6  | true     | 2021-10-26 | 0           |
+    | 7  | false    | 2021-10-26 | 1           |
+    | 2  | true     | 2021-10-26 | 0           |
+    | 3  | false    | 2021-10-26 | 1           |
+    | 0  | true     | 2021-10-26 | 0           |
+    | 1  | false    | 2021-10-26 | 1           |
+    +----+----------+------------+-------------+
+    ");}

         let batch = results.next().await;
         assert!(batch.is_none());
diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs
index 4f46a57d8b137..82c47b6c7281c 100644
--- a/datafusion/core/src/datasource/physical_plan/csv.rs
+++ b/datafusion/core/src/datasource/physical_plan/csv.rs
@@ -29,18 +29,21 @@ mod tests {
     use std::io::Write;
     use std::sync::Arc;

+    use datafusion_datasource::TableSchema;
     use datafusion_datasource_csv::CsvFormat;
-    use object_store::ObjectStore;
+    use object_store::{ObjectStore, ObjectStoreExt};

+    use crate::datasource::file_format::FileFormat;
     use crate::prelude::CsvReadOptions;
     use crate::prelude::SessionContext;
     use crate::test::partitioned_file_groups;
+    use datafusion_common::config::CsvOptions;
     use datafusion_common::test_util::arrow_test_data;
     use datafusion_common::test_util::batches_to_string;
-    use datafusion_common::{assert_batches_eq, Result};
+    use datafusion_common::{Result, assert_batches_eq};
     use datafusion_execution::config::SessionConfig;
-    use datafusion_physical_plan::metrics::MetricsSet;
     use datafusion_physical_plan::ExecutionPlan;
+    use datafusion_physical_plan::metrics::MetricsSet;

     #[cfg(feature = "compression")]
     use datafusion_datasource::file_compression_type::FileCompressionType;
@@ -94,32 +97,39 @@ mod tests {
     async fn csv_exec_with_projection(
         file_compression_type: FileCompressionType,
     ) -> Result<()> {
+        use datafusion_datasource::TableSchema;
+
         let session_ctx = SessionContext::new();
         let task_ctx = session_ctx.task_ctx();
         let file_schema = aggr_test_schema();
         let path = format!("{}/csv", arrow_test_data());
         let filename = "aggregate_test_100.csv";
         let tmp_dir = TempDir::new()?;
+        let csv_format: Arc<dyn FileFormat> = Arc::new(CsvFormat::default());
         let file_groups = partitioned_file_groups(
             path.as_str(),
             filename,
             1,
-            Arc::new(CsvFormat::default()),
+            &csv_format,
             file_compression_type.to_owned(),
             tmp_dir.path(),
         )?;

-        let source = Arc::new(CsvSource::new(true, b',', b'"'));
-        let config = FileScanConfigBuilder::from(partitioned_csv_config(
-            file_schema,
-            file_groups,
-            source,
-        ))
-        .with_file_compression_type(file_compression_type)
-        .with_newlines_in_values(false)
-        .with_projection_indices(Some(vec![0, 2, 4]))
-        .build();
+        let options = CsvOptions {
+            has_header: Some(true),
+            delimiter: b',',
+            quote: b'"',
+            ..Default::default()
+        };
+        let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema));
+        let source =
+            Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options));
+        let config =
+            FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?)
+                .with_file_compression_type(file_compression_type)
+                .with_projection_indices(Some(vec![0, 2, 4]))?
+                .build();

         assert_eq!(13, config.file_schema().fields().len());
         let csv = DataSourceExec::from_data_source(config);
@@ -131,17 +141,17 @@ mod tests {
         assert_eq!(3, batch.num_columns());
         assert_eq!(100, batch.num_rows());

-        insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +----+-----+------------+ - | c1 | c3 | c5 | - +----+-----+------------+ - | c | 1 | 2033001162 | - | d | -40 | 706441268 | - | b | 29 | 994303988 | - | a | -85 | 1171968280 | - | b | -82 | 1824882165 | - +----+-----+------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +----+-----+------------+ + | c1 | c3 | c5 | + +----+-----+------------+ + | c | 1 | 2033001162 | + | d | -40 | 706441268 | + | b | 29 | 994303988 | + | a | -85 | 1171968280 | + | b | -82 | 1824882165 | + +----+-----+------------+ + ");} Ok(()) } @@ -158,6 +168,8 @@ mod tests { async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -165,26 +177,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_projection_indices(Some(vec![4, 0, 2])) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_projection_indices(Some(vec![4, 0, 2]))? + .build(); assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(3, csv.schema().fields().len()); @@ -194,17 +211,17 @@ mod tests { assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +------------+----+-----+ - | c5 | c1 | c3 | - +------------+----+-----+ - | 2033001162 | c | 1 | - | 706441268 | d | -40 | - | 994303988 | b | 29 | - | 1171968280 | a | -85 | - | 1824882165 | b | -82 | - +------------+----+-----+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +------------+----+-----+ + | c5 | c1 | c3 | + +------------+----+-----+ + | 2033001162 | c | 1 | + | 706441268 | d | -40 | + | 994303988 | b | 29 | + | 1171968280 | a | -85 | + | 1824882165 | b | -82 | + +------------+----+-----+ + ");} Ok(()) } @@ -221,6 +238,7 @@ mod tests { async fn csv_exec_with_limit( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; use futures::StreamExt; let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); @@ -230,26 +248,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_limit(Some(5)) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)) + .build(); assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(13, csv.schema().fields().len()); @@ -259,17 +282,17 @@ mod tests { assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - | c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | - | d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | - | b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | - | a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | - | b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + | c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | + | d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | + | b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | + | a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | + | b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + ");} Ok(()) } @@ -287,32 +310,39 @@ mod tests { async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); let path = format!("{}/csv", arrow_test_data()); let 
filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_limit(Some(5)) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)) + .build(); assert_eq!(14, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(14, csv.schema().fields().len()); @@ -341,6 +371,7 @@ mod tests { file_compression_type: FileCompressionType, ) -> Result<()> { use datafusion_common::ScalarValue; + use datafusion_datasource::TableSchema; let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -348,12 +379,13 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let mut file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -362,19 +394,25 @@ mod tests { let num_file_schema_fields = file_schema.fields().len(); - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) - // We should be able to project on the partition column - // Which is supposed to be after the file fields - .with_projection_indices(Some(vec![0, num_file_schema_fields])) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::new( + Arc::clone(&file_schema), + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + ); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + // We should be able to project on the partition column + // Which is supposed to be after the file fields + .with_projection_indices(Some(vec![0, num_file_schema_fields]))? + .build(); // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway @@ -388,17 +426,17 @@ mod tests { assert_eq!(2, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +----+------------+ - | c1 | date | - +----+------------+ - | c | 2021-10-26 | - | d | 2021-10-26 | - | b | 2021-10-26 | - | a | 2021-10-26 | - | b | 2021-10-26 | - +----+------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +----+------------+ + | c1 | date | + +----+------------+ + | c | 2021-10-26 | + | d | 2021-10-26 | + | b | 2021-10-26 | + | a | 2021-10-26 | + | b | 2021-10-26 | + +----+------------+ + ");} let metrics = csv.metrics().expect("doesn't found metrics"); let time_elapsed_processing = get_value(&metrics, "time_elapsed_processing"); @@ -452,26 +490,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), ) .unwrap(); - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); let csv = DataSourceExec::from_data_source(config); let it = csv.execute(0, task_ctx).unwrap(); @@ -527,14 +570,14 @@ mod tests { let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result), @r###" - +---+---+ - | a | b | - +---+---+ - | 1 | 2 | - | 3 | 4 | - +---+---+ - "###); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+ + | a | b | + +---+---+ + | 1 | 2 | + | 3 | 4 | + +---+---+ + "); } #[tokio::test] @@ -556,14 +599,14 @@ mod tests { let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result),@r###" - +---+---+ - | a | b | - +---+---+ - | 1 | 2 | - | 3 | 4 | - +---+---+ - "###); + assert_snapshot!(batches_to_string(&result),@r" + +---+---+ + | a | b | + +---+---+ + | 1 | 2 | + | 3 | 4 | + +---+---+ + "); let e = session_ctx .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\n'))) @@ -572,7 +615,10 @@ mod tests { .collect() .await .unwrap_err(); - assert_eq!(e.strip_backtrace(), "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2") + assert_eq!( + e.strip_backtrace(), + "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2" + ) } #[tokio::test] @@ -593,22 +639,22 @@ mod tests { .await?; let df = ctx.sql(r#"select * from t1"#).await?.collect().await?; - assert_snapshot!(batches_to_string(&df),@r###" - +------+--------+ - | col1 | col2 | - +------+--------+ - | id0 | value0 | - | id1 | value1 | - | id2 | value2 | - | id3 | value3 | - +------+--------+ - "###); + assert_snapshot!(batches_to_string(&df),@r" + +------+--------+ + | col1 | col2 | + +------+--------+ + | id0 | value0 | + | id1 | value1 | + | id2 | value2 | + | id3 | value3 | + 
+------+--------+
+        ");

         Ok(())
     }

     #[tokio::test]
-    async fn test_create_external_table_with_terminator_with_newlines_in_values(
-    ) -> Result<()> {
+    async fn test_create_external_table_with_terminator_with_newlines_in_values()
+    -> Result<()> {
         let ctx = SessionContext::new();
         ctx.sql(r#"
 CREATE EXTERNAL TABLE t1 (
@@ -658,7 +704,10 @@ mod tests {
         )
         .await
         .expect_err("should fail because input file does not match inferred schema");
-        assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'");
+        assert_eq!(
+            e.strip_backtrace(),
+            "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"
+        );

         Ok(())
     }
diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs
index f7d5c710bf48a..b70791c7b2390 100644
--- a/datafusion/core/src/datasource/physical_plan/json.rs
+++ b/datafusion/core/src/datasource/physical_plan/json.rs
@@ -32,11 +32,11 @@ mod tests {

     use crate::dataframe::DataFrameWriteOptions;
     use crate::execution::SessionState;
-    use crate::prelude::{CsvReadOptions, NdJsonReadOptions, SessionContext};
+    use crate::prelude::{CsvReadOptions, JsonReadOptions, SessionContext};
     use crate::test::partitioned_file_groups;
+    use datafusion_common::Result;
     use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array};
     use datafusion_common::test_util::batches_to_string;
-    use datafusion_common::Result;
     use datafusion_datasource::file_compression_type::FileCompressionType;
     use datafusion_datasource::file_format::FileFormat;
     use datafusion_datasource_json::JsonFormat;
@@ -51,9 +51,9 @@ mod tests {
     use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
     use datafusion_datasource::source::DataSourceExec;
     use insta::assert_snapshot;
+    use object_store::ObjectStore;
     use object_store::chunked::ChunkedStore;
     use object_store::local::LocalFileSystem;
-    use object_store::ObjectStore;
     use rstest::*;
     use tempfile::TempDir;
     use url::Url;
@@ -69,11 +69,13 @@ mod tests {
         let store = state.runtime_env().object_store(&store_url).unwrap();

         let filename = "1.json";
+        let json_format: Arc<dyn FileFormat> = Arc::new(JsonFormat::default());
+
         let file_groups = partitioned_file_groups(
             TEST_DATA_BASE,
             filename,
             1,
-            Arc::new(JsonFormat::default()),
+            &json_format,
             file_compression_type.to_owned(),
             work_dir,
         )
@@ -104,11 +106,13 @@ mod tests {
         ctx.register_object_store(&url, store.clone());
         let filename = "1.json";
         let tmp_dir = TempDir::new()?;
+        let json_format: Arc<dyn FileFormat> = Arc::new(JsonFormat::default());
+
         let file_groups = partitioned_file_groups(
             TEST_DATA_BASE,
             filename,
             1,
-            Arc::new(JsonFormat::default()),
+            &json_format,
             file_compression_type.to_owned(),
             tmp_dir.path(),
         )
@@ -132,22 +136,22 @@ mod tests {
             .get_ext_with_compression(&file_compression_type)
             .unwrap();

-        let read_options = NdJsonReadOptions::default()
+        let read_options = JsonReadOptions::default()
             .file_extension(ext.as_str())
             .file_compression_type(file_compression_type.to_owned());

         let frame = ctx.read_json(path, read_options).await.unwrap();
         let results = frame.collect().await.unwrap();

-        insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&results), @r###" - +-----+------------------+---------------+------+ - | a | b | c | d | - +-----+------------------+---------------+------+ - | 1 | [2.0, 1.3, -6.1] | [false, true] | 4 | - | -10 | [2.0, 1.3, -6.1] | [true, true] | 4 | - | 2 | [2.0, , -6.1] | [false, ] | text | - | | | | | - +-----+------------------+---------------+------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&results), @r" + +-----+------------------+---------------+------+ + | a | b | c | d | + +-----+------------------+---------------+------+ + | 1 | [2.0, 1.3, -6.1] | [false, true] | 4 | + | -10 | [2.0, 1.3, -6.1] | [true, true] | 4 | + | 2 | [2.0, , -6.1] | [false, ] | text | + | | | | | + +-----+------------------+---------------+------+ + ");} Ok(()) } @@ -176,8 +180,8 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()) @@ -251,8 +255,8 @@ mod tests { let file_schema = Arc::new(builder.finish()); let missing_field_idx = file_schema.fields.len() - 1; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()) @@ -294,10 +298,11 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_projection_indices(Some(vec![0, 2])) + .unwrap() .with_file_compression_type(file_compression_type.to_owned()) .build(); let exec = DataSourceExec::from_data_source(conf); @@ -342,10 +347,10 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) - .with_projection_indices(Some(vec![3, 0, 2])) + .with_projection_indices(Some(vec![3, 0, 2]))? 
.with_file_compression_type(file_compression_type.to_owned()) .build(); let exec = DataSourceExec::from_data_source(conf); @@ -384,7 +389,7 @@ mod tests { let path = format!("{TEST_DATA_BASE}/1.json"); // register json file with the execution context - ctx.register_json("test", path.as_str(), NdJsonReadOptions::default()) + ctx.register_json("test", path.as_str(), JsonReadOptions::default()) .await?; // register a local file system object store for /tmp directory @@ -426,7 +431,7 @@ mod tests { } // register each partition as well as the top level dir - let json_read_option = NdJsonReadOptions::default(); + let json_read_option = JsonReadOptions::default(); ctx.register_json( "part0", &format!("{out_dir}/{part_0_name}"), @@ -494,7 +499,10 @@ mod tests { .write_json(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'" + ); Ok(()) } @@ -503,7 +511,7 @@ mod tests { async fn read_test_data(schema_infer_max_records: usize) -> Result { let ctx = SessionContext::new(); - let options = NdJsonReadOptions { + let options = JsonReadOptions { schema_infer_max_records, ..Default::default() }; @@ -579,7 +587,7 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_option = NdJsonReadOptions::default() + let read_option = JsonReadOptions::default() .file_compression_type(file_compression_type) .file_extension(ext.as_str()); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 1ac292e260fdf..8e4855afa66bb 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -43,146 +43,11 @@ pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; pub use datafusion_datasource::file_groups::FileGroupPartitioner; pub use datafusion_datasource::file_scan_config::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, - FileScanConfigBuilder, + FileScanConfig, FileScanConfigBuilder, wrap_partition_type_in_dict, + wrap_partition_value_in_dict, }; pub use datafusion_datasource::file_sink_config::*; pub use datafusion_datasource::file_stream::{ - FileOpenFuture, FileOpener, FileStream, OnError, + FileOpenFuture, FileOpener, FileStream, FileStreamBuilder, OnError, }; - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::{ - cast::AsArray, - types::{Float32Type, Float64Type, UInt32Type}, - BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, - StringArray, UInt64Array, - }; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SchemaRef; - - use crate::datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapterFactory, - }; - - #[test] - fn schema_mapping_map_batch() { - let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::UInt32, true), - Field::new("c3", DataType::Float64, true), - ])); - - let adapter = DefaultSchemaAdapterFactory - .create(table_schema.clone(), table_schema.clone()); - - let file_schema = Schema::new(vec![ - Field::new("c1", DataType::Utf8, true), 
-            Field::new("c2", DataType::UInt64, true),
-            Field::new("c3", DataType::Float32, true),
-        ]);
-
-        let (mapping, _) = adapter.map_schema(&file_schema).expect("map schema failed");
-
-        let c1 = StringArray::from(vec!["hello", "world"]);
-        let c2 = UInt64Array::from(vec![9_u64, 5_u64]);
-        let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]);
-        let batch = RecordBatch::try_new(
-            Arc::new(file_schema),
-            vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)],
-        )
-        .unwrap();
-
-        let mapped_batch = mapping.map_batch(batch).unwrap();
-
-        assert_eq!(mapped_batch.schema(), table_schema);
-        assert_eq!(mapped_batch.num_columns(), 3);
-        assert_eq!(mapped_batch.num_rows(), 2);
-
-        let c1 = mapped_batch.column(0).as_string::<i32>();
-        let c2 = mapped_batch.column(1).as_primitive::<UInt32Type>();
-        let c3 = mapped_batch.column(2).as_primitive::<Float64Type>();
-
-        assert_eq!(c1.value(0), "hello");
-        assert_eq!(c1.value(1), "world");
-        assert_eq!(c2.value(0), 9_u32);
-        assert_eq!(c2.value(1), 5_u32);
-        assert_eq!(c3.value(0), 2.0_f64);
-        assert_eq!(c3.value(1), 7.0_f64);
-    }
-
-    #[test]
-    fn schema_adapter_map_schema_with_projection() {
-        let table_schema = Arc::new(Schema::new(vec![
-            Field::new("c0", DataType::Utf8, true),
-            Field::new("c1", DataType::Utf8, true),
-            Field::new("c2", DataType::Float64, true),
-            Field::new("c3", DataType::Int32, true),
-            Field::new("c4", DataType::Float32, true),
-        ]));
-
-        let file_schema = Schema::new(vec![
-            Field::new("id", DataType::Int32, true),
-            Field::new("c1", DataType::Boolean, true),
-            Field::new("c2", DataType::Float32, true),
-            Field::new("c3", DataType::Binary, true),
-            Field::new("c4", DataType::Int64, true),
-        ]);
-
-        let indices = vec![1, 2, 4];
-        let schema = SchemaRef::from(table_schema.project(&indices).unwrap());
-        let adapter = DefaultSchemaAdapterFactory.create(schema, table_schema.clone());
-        let (mapping, projection) = adapter.map_schema(&file_schema).unwrap();
-
-        let id = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
-        let c1 = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);
-        let c2 = Float32Array::from(vec![Some(2.0_f32), Some(7.0_f32), Some(3.0_f32)]);
-        let c3 = BinaryArray::from_opt_vec(vec![
-            Some(b"hallo"),
-            Some(b"danke"),
-            Some(b"super"),
-        ]);
-        let c4 = Int64Array::from(vec![1, 2, 3]);
-        let batch = RecordBatch::try_new(
-            Arc::new(file_schema),
-            vec![
-                Arc::new(id),
-                Arc::new(c1),
-                Arc::new(c2),
-                Arc::new(c3),
-                Arc::new(c4),
-            ],
-        )
-        .unwrap();
-        let rows_num = batch.num_rows();
-        let projected = batch.project(&projection).unwrap();
-        let mapped_batch = mapping.map_batch(projected).unwrap();
-
-        assert_eq!(
-            mapped_batch.schema(),
-            Arc::new(table_schema.project(&indices).unwrap())
-        );
-        assert_eq!(mapped_batch.num_columns(), indices.len());
-        assert_eq!(mapped_batch.num_rows(), rows_num);
-
-        let c1 = mapped_batch.column(0).as_string::<i32>();
-        let c2 = mapped_batch.column(1).as_primitive::<Float64Type>();
-        let c4 = mapped_batch.column(2).as_primitive::<Float32Type>();
-
-        assert_eq!(c1.value(0), "true");
-        assert_eq!(c1.value(1), "false");
-        assert_eq!(c1.value(2), "true");
-
-        assert_eq!(c2.value(0), 2.0_f64);
-        assert_eq!(c2.value(1), 7.0_f64);
-        assert_eq!(c2.value(2), 3.0_f64);
-
-        assert_eq!(c4.value(0), 1.0_f32);
-        assert_eq!(c4.value(1), 2.0_f32);
-        assert_eq!(c4.value(2), 3.0_f32);
-    }
-}
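The deleted `DefaultSchemaAdapter` tests above pinned down two behaviors that the expression-adapter approach now has to preserve: a table column missing from the file is materialized as nulls, and a file column with a mismatched type is cast to the table type. The same semantics can be checked directly against plain arrow-rs kernels; this sketch is independent of the removed adapter API, and the column names are taken from the deleted tests:

use std::sync::Arc;
use arrow::array::{new_null_array, ArrayRef, Float64Array};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The file only has "b" as Float64; the table wants "a" Int32 and "b" Utf8.
    let b_file: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0]));
    let a = new_null_array(&DataType::Int32, b_file.len()); // missing column -> nulls
    let b = cast(&b_file, &DataType::Utf8)?; // 1.0 -> "1.0", 2.0 -> "2.0"
    let table_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]));
    let mapped = RecordBatch::try_new(table_schema, vec![a, b])?;
    assert_eq!(mapped.num_rows(), 2);
    assert_eq!(mapped.column(0).null_count(), 2);
    Ok(())
}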
@@ -38,29 +38,29 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ - ArrayRef, AsArray, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, - StringViewArray, StructArray, TimestampNanosecondArray, + ArrayRef, AsArray, Date64Array, DictionaryArray, Int8Array, Int32Array, + Int64Array, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, }; - use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; + use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder, UInt16Type}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SchemaRef, TimeUnit}; use bytes::{BufMut, BytesMut}; use datafusion_common::config::TableParquetOptions; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; - use datafusion_common::{assert_contains, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, assert_contains}; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::file::FileSource; - use datafusion_datasource::{FileRange, PartitionedFile, TableSchema}; + use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_datasource_parquet::{ DefaultParquetFileReaderFactory, ParquetFileReaderFactory, ParquetFormat, }; use datafusion_execution::object_store::ObjectStoreUrl; - use datafusion_expr::{col, lit, when, Expr}; + use datafusion_expr::{Expr, col, lit, when}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::analyze::AnalyzeExec; use datafusion_physical_plan::collect; @@ -161,7 +161,7 @@ mod tests { .as_ref() .map(|p| logical2physical(p, &table_schema)); - let mut source = ParquetSource::default(); + let mut source = ParquetSource::new(table_schema); if let Some(predicate) = predicate { source = source.with_predicate(predicate); } @@ -186,23 +186,20 @@ mod tests { source = source.with_bloom_filter_on_read(false); } - source.with_schema(TableSchema::new(Arc::clone(&table_schema), vec![])) + Arc::new(source) } fn build_parquet_exec( &self, - file_schema: SchemaRef, file_group: FileGroup, source: Arc<ParquetSource>, ) -> Arc<DataSourceExec> { - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file_group(file_group) - .with_projection_indices(self.projection.clone()) - .build(); + let base_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file_group(file_group) + .with_projection_indices(self.projection.clone()) + .unwrap() + .build(); DataSourceExec::from_data_source(base_config) } @@ -231,19 +228,16 @@ mod tests { // build a ParquetExec to return the results let parquet_source = self.build_file_source(Arc::clone(table_schema)); - let parquet_exec = self.build_parquet_exec( - Arc::clone(table_schema), - file_group.clone(), - Arc::clone(&parquet_source), - ); + let parquet_exec = + self.build_parquet_exec(file_group.clone(), Arc::clone(&parquet_source)); let analyze_exec = Arc::new(AnalyzeExec::new( false, false, - vec![MetricType::SUMMARY, MetricType::DEV], + vec![MetricType::Summary, MetricType::Dev], + None, // use a new ParquetSource to avoid sharing execution metrics
self.build_parquet_exec( - Arc::clone(table_schema), file_group.clone(), self.build_file_source(Arc::clone(table_schema)), ), @@ -313,7 +307,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit(1_i32)); @@ -344,13 +338,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+ | c1 | c2 | +----+----+ | 1 | | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -371,7 +365,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); @@ -402,13 +396,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+ | c1 | c2 | +----+----+ | 1 | | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -433,7 +427,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1, c3]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); @@ -464,13 +458,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | 1 | | 7 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -495,7 +489,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c3.clone(), c3]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. 
let filter = col("c2").eq(lit("abc")); @@ -526,13 +520,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | | | 7 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -575,13 +569,13 @@ mod tests { let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | 1 | | 10 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -605,7 +599,7 @@ mod tests { let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ @@ -613,7 +607,7 @@ mod tests { | 4 | | 40 | | 5 | | 50 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -642,7 +636,7 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read), @r###" + insta::assert_snapshot!(batches_to_sort_string(&read), @r" +-----+----+----+ | c1 | c2 | c3 | +-----+----+----+ @@ -656,7 +650,7 @@ mod tests { | bar | | | | bar | | | +-----+----+----+ - "###); + "); } #[tokio::test] @@ -757,18 +751,18 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+----+ - | c1 | c3 | c2 | - +-----+----+----+ - | | | | - | | 10 | 1 | - | | 20 | | - | | 20 | 2 | - | Foo | 10 | | - | bar | | | - +-----+----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+----+ + | c1 | c3 | c2 | + +-----+----+----+ + | | | | + | | 10 | 1 | + | | 20 | | + | | 20 | 2 | + | Foo | 10 | | + | bar | | | + +-----+----+----+ + "); } #[tokio::test] @@ -789,14 +783,14 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +----+----+----+ | c1 | c3 | c2 | +----+----+----+ | | 10 | 1 | | | 20 | 2 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 4); @@ -832,7 +826,7 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read), @r###" + insta::assert_snapshot!(batches_to_sort_string(&read), @r" +-----+-----+ | c1 | c4 | +-----+-----+ @@ -843,7 +837,7 @@ mod tests { | bar | | | bar | | +-----+-----+ - "###); + "); } #[tokio::test] @@ -1002,6 +996,7 @@ mod tests { assert_eq!(read, 1, "Expected 1 rows to match the predicate"); assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2); + assert_eq!(get_value(&metrics, "page_index_pages_pruned"), 1); assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1); // If we filter with a value that is completely out of the range of the data // we prune at the row group level. 
@@ -1056,18 +1051,18 @@ mod tests { // In a real query where this predicate was pushed down from a filter stage instead of created directly in the `DataSourceExec`, // the filter stage would be preserved as a separate execution plan stage so the actual query results would be as expected. - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | | | - | | | - | | 1 | - | | 2 | - | Foo | | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | | | + | | | + | | 1 | + | | 2 | + | Foo | | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1092,13 +1087,13 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +----+----+ | c1 | c2 | +----+----+ | | 1 | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there were 6 rows in total (across three batches) assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 5); @@ -1152,7 +1147,7 @@ mod tests { .round_trip(vec![batch1, batch2, batch3, batch4]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +------+----+ | c1 | c2 | +------+----+ @@ -1169,16 +1164,22 @@ mod tests { | Foo2 | | | Foo3 | | +------+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // There are 4 rows pruned in each of batch2, batch3, and // batch4 for a total of 12. batch1 had no pruning as c2 was // filled in as null - let (page_index_pruned, page_index_matched) = + let (page_index_rows_pruned, page_index_rows_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); - assert_eq!(page_index_pruned, 12); - assert_eq!(page_index_matched, 6); + assert_eq!(page_index_rows_pruned, 12); + assert_eq!(page_index_rows_matched, 6); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 6); + assert_eq!(page_index_pages_matched, 3); } #[tokio::test] @@ -1201,14 +1202,14 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | Foo | 1 | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | Foo | 1 | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1231,15 +1232,15 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | | 2 | - | Foo | 1 | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | | 2 | + | Foo | 1 | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1264,7 +1265,7 @@ mod tests { ("c3", c3.clone()), ]); - // batch2: c3(int8), c2(int64), c1(string), c4(string) + // batch2: c3(date64), c2(int64), c1(string) let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); let table_schema = Schema::new(vec![ @@ -1278,8 +1279,10 @@ mod tests { .with_table_schema(Arc::new(table_schema)) .round_trip_to_batches(vec![batch1, batch2]) .await; -
assert_contains!(read.unwrap_err().to_string(), - "Cannot cast file schema field c3 of type Date64 to table schema field of type Int8"); + assert_contains!( + read.unwrap_err().to_string(), + "Cannot cast column 'c3' from 'Date64' (physical data type) to 'Int8' (logical data type)" + ); } #[tokio::test] @@ -1329,7 +1332,7 @@ mod tests { async fn parquet_exec_with_int96_from_spark() -> Result<()> { // arrow-rs relies on the chrono library to convert between timestamps and strings, so // instead compare as Int64. The underlying type should be a PrimitiveArray of Int64 - // anyway, so this should be a zero-copy non-modifying cast at the SchemaAdapter. + // anyway, so this should be a zero-copy non-modifying cast. let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); let testdata = datafusion_common::test_util::parquet_test_data(); @@ -1532,14 +1535,7 @@ mod tests { #[tokio::test] async fn parquet_exec_with_range() -> Result<()> { fn file_range(meta: &ObjectMeta, start: i64, end: i64) -> PartitionedFile { - PartitionedFile { - object_meta: meta.clone(), - partition_values: vec![], - range: Some(FileRange { start, end }), - statistics: None, - extensions: None, - metadata_size_hint: None, - } + PartitionedFile::new_from_meta(meta.clone()).with_range(start, end) } async fn assert_parquet_read( @@ -1550,8 +1546,7 @@ mod tests { ) -> Result<()> { let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file_groups(file_groups) .build(); @@ -1622,21 +1617,15 @@ mod tests { .await .unwrap(); - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![ + let partitioned_file = PartitionedFile::new_from_meta(meta) + .with_partition_values(vec![ ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), Box::new(ScalarValue::from("26")), ), - ], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + ]); let expected_schema = Schema::new(vec![ Field::new("id", DataType::Int32, true), @@ -1653,23 +1642,27 @@ mod tests { ), ]); - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new(object_store_url, schema.clone(), source) - .with_file(partitioned_file) - // file has 10 cols so index 12 should be month and 13 should be day - .with_projection_indices(Some(vec![0, 1, 2, 12, 13])) - .with_table_partition_cols(vec![ - Field::new("year", DataType::Utf8, false), - Field::new("month", DataType::UInt8, false), - Field::new( + let table_schema = TableSchema::new( + Arc::clone(&schema), + vec![ + Arc::new(Field::new("year", DataType::Utf8, false)), + Arc::new(Field::new("month", DataType::UInt8, false)), + Arc::new(Field::new( "day", DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), ), false, - ), - ]) + )), + ], + ); + let source = Arc::new(ParquetSource::new(table_schema.clone())); + let config = FileScanConfigBuilder::new(object_store_url, source) + .with_file(partitioned_file) + // file has 10 cols so index 12 should be month and 13 should be day + .with_projection_indices(Some(vec![0, 1, 2, 12, 13])) + .unwrap() .build(); let parquet_exec = DataSourceExec::from_data_source(config); @@ -1684,20 +1677,20 @@ mod tests { let batch = results.next().await.unwrap()?; assert_eq!(batch.schema().as_ref(), &expected_schema); - assert_snapshot!(batches_to_string(&[batch]),@r###" - 
+----+----------+-------------+-------+-----+ - | id | bool_col | tinyint_col | month | day | - +----+----------+-------------+-------+-----+ - | 4 | true | 0 | 10 | 26 | - | 5 | false | 1 | 10 | 26 | - | 6 | true | 0 | 10 | 26 | - | 7 | false | 1 | 10 | 26 | - | 2 | true | 0 | 10 | 26 | - | 3 | false | 1 | 10 | 26 | - | 0 | true | 0 | 10 | 26 | - | 1 | false | 1 | 10 | 26 | - +----+----------+-------------+-------+-----+ - "###); + assert_snapshot!(batches_to_string(&[batch]),@r" + +----+----------+-------------+-------+-----+ + | id | bool_col | tinyint_col | month | day | + +----+----------+-------------+-------+-----+ + | 4 | true | 0 | 10 | 26 | + | 5 | false | 1 | 10 | 26 | + | 6 | true | 0 | 10 | 26 | + | 7 | false | 1 | 10 | 26 | + | 2 | true | 0 | 10 | 26 | + | 3 | false | 1 | 10 | 26 | + | 0 | true | 0 | 10 | 26 | + | 1 | false | 1 | 10 | 26 | + +----+----------+-------------+-------+-----+ + "); let batch = results.next().await; assert!(batch.is_none()); @@ -1711,28 +1704,20 @@ mod tests { let state = session_ctx.state(); let location = Path::from_filesystem_path(".") .unwrap() - .child("invalid.parquet"); + .join("invalid.parquet"); - let partitioned_file = PartitionedFile { - object_meta: ObjectMeta { - location, - last_modified: Utc.timestamp_nanos(0), - size: 1337, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(ObjectMeta { + location, + last_modified: Utc.timestamp_nanos(0), + size: 1337, + e_tag: None, + version: None, + }); let file_schema = Arc::new(Schema::empty()); let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file(partitioned_file) .build(); @@ -1757,6 +1742,7 @@ mod tests { Some(3), Some(4), Some(5), + Some(6), // last page with only one row ])); let batch1 = create_batch(vec![("int", c1.clone())]); @@ -1765,27 +1751,53 @@ mod tests { let rt = RoundTrip::new() .with_predicate(filter) .with_page_index_predicate() - .round_trip(vec![batch1]) + .round_trip(vec![batch1.clone()]) .await; let metrics = rt.parquet_exec.metrics().unwrap(); - assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()),@r###" - +-----+ - | int | - +-----+ - | 4 | - | 5 | - +-----+ - "###); - let (page_index_pruned, page_index_matched) = + assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()),@r" + +-----+ + | int | + +-----+ + | 4 | + | 5 | + +-----+ + "); + let (page_index_rows_pruned, page_index_rows_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); - assert_eq!(page_index_pruned, 4); - assert_eq!(page_index_matched, 2); + assert_eq!(page_index_rows_pruned, 5); + assert_eq!(page_index_rows_matched, 2); assert!( get_value(&metrics, "page_index_eval_time") > 0, "no eval time in metrics: {metrics:#?}" ); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); + + // test with a filter that matches the page with one row + let filter = col("int").eq(lit(6_i32)); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_page_index_predicate() + .round_trip(vec![batch1]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + let 
(page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metric(&metrics, "page_index_rows_pruned"); + assert_eq!(page_index_rows_pruned, 6); + assert_eq!(page_index_rows_matched, 1); + + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); } /// Returns a string array with contents: @@ -1823,14 +1835,14 @@ mod tests { let metrics = rt.parquet_exec.metrics().unwrap(); // assert the batches and some metrics - assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r###" - +-----+ - | c1 | - +-----+ - | Foo | - | zzz | - +-----+ - "###); + assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r" + +-----+ + | c1 | + +-----+ + | Foo | + | zzz | + +-----+ + "); // pushdown predicates have eliminated all 4 bar rows and the // null row for 5 rows total @@ -1879,6 +1891,100 @@ mod tests { assert_contains!(&explain, "projection=[c1]"); } + #[tokio::test] + async fn parquet_exec_metrics_with_multiple_predicates() { + // Test that metrics are correctly calculated when multiple predicates + // are pushed down (connected with AND). This ensures we don't double-count + // rows when multiple predicates filter the data sequentially. + + // Create a batch with two columns: c1 (string) and c2 (int32) + // Total: 10 rows + let c1: ArrayRef = Arc::new(StringArray::from(vec![ + Some("foo"), // 0 - passes c1 filter, fails c2 filter (5 <= 10) + Some("bar"), // 1 - fails c1 filter + Some("bar"), // 2 - fails c1 filter + Some("baz"), // 3 - passes both filters (20 > 10) + Some("foo"), // 4 - passes both filters (12 > 10) + Some("bar"), // 5 - fails c1 filter + Some("baz"), // 6 - passes both filters (25 > 10) + Some("foo"), // 7 - passes c1 filter, fails c2 filter (7 <= 10) + Some("bar"), // 8 - fails c1 filter + Some("qux"), // 9 - passes both filters (30 > 10) + ])); + + let c2: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(5), + Some(15), + Some(8), + Some(20), + Some(12), + Some(9), + Some(25), + Some(7), + Some(18), + Some(30), + ])); + + let batch = create_batch(vec![("c1", c1), ("c2", c2)]); + + // Create filter: c1 != 'bar' AND c2 > 10 + // + // First predicate (c1 != 'bar'): + // - Rows passing: 0, 3, 4, 6, 7, 9 (6 rows) + // - Rows pruned: 1, 2, 5, 8 (4 rows) + // + // Second predicate (c2 > 10) on remaining 6 rows: + // - Rows passing: 3, 4, 6, 9 (4 rows with c2 = 20, 12, 25, 30) + // - Rows pruned: 0, 7 (2 rows with c2 = 5, 7) + // + // Expected final metrics: + // - pushdown_rows_matched: 4 (final result) + // - pushdown_rows_pruned: 4 + 2 = 6 (cumulative) + // - Total: 4 + 6 = 10 + + let filter = col("c1").not_eq(lit("bar")).and(col("c2").gt(lit(10))); + + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .round_trip(vec![batch]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + // Verify the result rows + assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | baz | 20 | + | foo | 12 | + | baz | 25 | + | qux | 30 | + +-----+----+ + "); + + // Verify metrics - this is the key test + let pushdown_rows_matched = get_value(&metrics, "pushdown_rows_matched"); + let pushdown_rows_pruned = get_value(&metrics, "pushdown_rows_pruned"); + + assert_eq!( + pushdown_rows_matched, 4, + "Expected 4 rows to pass both predicates" + ); + assert_eq!( + pushdown_rows_pruned, 6, + "Expected 6 rows to be pruned (4 by first predicate + 2 by second 
predicate)" + ); + + // The sum should equal the total number of rows + assert_eq!( + pushdown_rows_matched + pushdown_rows_pruned, + 10, + "matched + pruned should equal total rows" + ); + } + #[tokio::test] async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) @@ -2119,13 +2225,13 @@ mod tests { let sql = "select * from base_table where name='test02'"; let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); assert_eq!(batch.len(), 1); - insta::assert_snapshot!(batches_to_string(&batch),@r###" - +---------------------+----+--------+ - | struct | id | name | - +---------------------+----+--------+ - | {id: 4, name: aaa2} | 2 | test02 | - +---------------------+----+--------+ - "###); + insta::assert_snapshot!(batches_to_string(&batch),@r" + +---------------------+----+--------+ + | struct | id | name | + +---------------------+----+--------+ + | {id: 4, name: aaa2} | 2 | test02 | + +---------------------+----+--------+ + "); Ok(()) } @@ -2148,13 +2254,55 @@ mod tests { let sql = "select * from base_table where name='test02'"; let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); assert_eq!(batch.len(), 1); - insta::assert_snapshot!(batches_to_string(&batch),@r###" - +---------------------+----+--------+ - | struct | id | name | - +---------------------+----+--------+ - | {id: 4, name: aaa2} | 2 | test02 | - +---------------------+----+--------+ - "###); + insta::assert_snapshot!(batches_to_string(&batch),@r" + +---------------------+----+--------+ + | struct | id | name | + +---------------------+----+--------+ + | {id: 4, name: aaa2} | 2 | test02 | + +---------------------+----+--------+ + "); + Ok(()) + } + + /// Tests that constant dictionary columns (where min == max in statistics) + /// are correctly handled. This reproduced a bug where the constant value + /// from statistics had type Utf8 but the schema expected Dictionary. 
+ #[tokio::test] + async fn test_constant_dictionary_column_parquet() -> Result<()> { + let tmp_dir = TempDir::new()?; + let path = tmp_dir.path().to_str().unwrap().to_string() + "/test.parquet"; + + // Write parquet with dictionary column where all values are the same + let schema = Arc::new(Schema::new(vec![Field::new( + "status", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + false, + )])); + let status: DictionaryArray<UInt16Type> = + vec!["active", "active"].into_iter().collect(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(status)])?; + let file = File::create(&path)?; + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::Page) + .build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + // Query the constant dictionary column + let ctx = SessionContext::new(); + ctx.register_parquet("t", &path, ParquetReadOptions::default()) + .await?; + let result = ctx.sql("SELECT status FROM t").await?.collect().await?; + + insta::assert_snapshot!(batches_to_string(&result),@r" + +--------+ + | status | + +--------+ + | active | + | active | + +--------+ + "); Ok(()) } @@ -2279,42 +2427,28 @@ mod tests { let size_hint_calls = reader_factory.metadata_size_hint_calls.clone(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(Arc::clone(&schema)) .with_parquet_file_reader_factory(reader_factory) .with_metadata_size_hint(456), ); - let config = FileScanConfigBuilder::new(store_url, schema, source) + let config = FileScanConfigBuilder::new(store_url, source) .with_file( - PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(name_1), - last_modified: Utc::now(), - size: total_size_1, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - .with_metadata_size_hint(123), - ) - .with_file(PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(name_2), + PartitionedFile::new_from_meta(ObjectMeta { + location: Path::from(name_1), last_modified: Utc::now(), - size: total_size_2, + size: total_size_1, e_tag: None, version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }) + }) + .with_metadata_size_hint(123), + ) + .with_file(PartitionedFile::new_from_meta(ObjectMeta { + location: Path::from(name_2), + last_modified: Utc::now(), + size: total_size_2, + e_tag: None, + version: None, + })) .build(); let exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/src/datasource/view_test.rs b/datafusion/core/src/datasource/view_test.rs index 85ad9ff664ade..35418d6dea632 100644 --- a/datafusion/core/src/datasource/view_test.rs +++ b/datafusion/core/src/datasource/view_test.rs @@ -46,13 +46,13 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---+ | b | +---+ | 2 | +---+ - "###); + "); Ok(()) } @@ -96,14 +96,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column1 | column2 | column3 | +---------+---------+---------+ | 1 | 2 | 3 | | 4 | 5 | 6 | +---------+---------+---------+ - "###); + "); let view_sql = "CREATE VIEW replace_xyz AS SELECT * REPLACE (column1*2 as
column1) FROM xyz"; @@ -115,14 +115,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column1 | column2 | column3 | +---------+---------+---------+ | 2 | 2 | 3 | | 8 | 5 | 6 | +---------+---------+---------+ - "###); + "); Ok(()) } @@ -146,14 +146,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------------+ | column1_alias | +---------------+ | 1 | | 4 | +---------------+ - "###); + "); Ok(()) } @@ -177,14 +177,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------------+---------------+ | column2_alias | column1_alias | +---------------+---------------+ | 2 | 1 | | 5 | 4 | +---------------+---------------+ - "###); + "); Ok(()) } @@ -213,14 +213,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 1 | | 4 | +---------+ - "###); + "); Ok(()) } @@ -249,13 +249,13 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 4 | +---------+ - "###); + "); Ok(()) } @@ -287,14 +287,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column2 | column1 | column3 | +---------+---------+---------+ | 2 | 1 | 3 | | 5 | 4 | 6 | +---------+---------+---------+ - "###); + "); Ok(()) } @@ -358,7 +358,10 @@ mod tests { .to_string(); assert!(formatted.contains("DataSourceExec: ")); assert!(formatted.contains("file_type=parquet")); - assert!(formatted.contains("projection=[bool_col, int_col], limit=10")); + assert!( + formatted.contains("projection=[bool_col, int_col], limit=10"), + "{formatted}" + ); Ok(()) } @@ -442,14 +445,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 1 | | 4 | +---------+ - "###); + "); Ok(()) } diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs index e9d799400863d..f7df2ad7a1cd6 100644 --- a/datafusion/core/src/execution/context/json.rs +++ b/datafusion/core/src/execution/context/json.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use super::super::options::ReadOptions; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; +use crate::execution::options::JsonReadOptions; use datafusion_common::TableReference; use datafusion_datasource_json::source::plan_to_json; use std::sync::Arc; -use super::super::options::{NdJsonReadOptions, ReadOptions}; -use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; - impl SessionContext { /// Creates a [`DataFrame`] for reading an JSON data source. 
/// @@ -32,7 +32,7 @@ impl SessionContext { pub async fn read_json<P: DataFilePaths>( &self, table_paths: P, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result<DataFrame> { self._read_type(table_paths, options).await } @@ -43,7 +43,7 @@ impl SessionContext { &self, table_ref: impl Into<TableReference>, table_path: impl AsRef<str>, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 687779787ab50..87170f595f413 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::sync::{Arc, Weak}; +use std::time::Duration; use super::options::ReadOptions; use crate::datasource::dynamic_file::DynamicListTableFactory; @@ -33,20 +34,20 @@ use crate::{ datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }, - datasource::{provider_as_source, MemTable, ViewTable}, + datasource::{MemTable, ViewTable, provider_as_source}, error::Result, execution::{ + FunctionRegistry, options::ArrowReadOptions, runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, - FunctionRegistry, }, logical_expr::AggregateUDF, logical_expr::ScalarUDF, logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable, - TableType, UNNAMED_TABLE, + DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, ResetVariable, + SetVariable, TableType, UNNAMED_TABLE, }, physical_expr::PhysicalExpr, physical_plan::ExecutionPlan, @@ -58,32 +59,44 @@ pub use crate::execution::session_state::SessionState; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_catalog::memory::MemorySchemaProvider; use datafusion_catalog::MemoryCatalogProvider; +use datafusion_catalog::memory::MemorySchemaProvider; use datafusion_catalog::{ DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory, }; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigField, ConfigOptions}; use datafusion_common::metadata::ScalarAndMetadata; use datafusion_common::{ + DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference, config::{ConfigExtension, TableOptions}, exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference, +}; +pub use datafusion_execution::TaskContext; +use datafusion_execution::cache::cache_manager::{ + DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT, DEFAULT_LIST_FILES_CACHE_TTL, + DEFAULT_METADATA_CACHE_LIMIT, }; pub use datafusion_execution::config::SessionConfig; +use datafusion_execution::disk_manager::{ + DEFAULT_MAX_TEMP_DIRECTORY_SIZE, DiskManagerBuilder, +}; use datafusion_execution::registry::SerializerRegistry; -pub use datafusion_execution::TaskContext; +use datafusion_expr::HigherOrderUDF; pub use datafusion_expr::execution_props::ExecutionProps; +#[cfg(feature = "sql")] +use datafusion_expr::planner::RelationPlanner; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ + Expr, UserDefinedLogicalNode, WindowUDF,
expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, }; use datafusion_optimizer::analyzer::type_coercion::TypeCoercion; -use datafusion_optimizer::Analyzer; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; +use datafusion_optimizer::{Analyzer, OptimizerContext}; use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; use datafusion_session::SessionStore; @@ -242,7 +255,7 @@ where /// let state = SessionStateBuilder::new() /// .with_config(config) /// .with_runtime_env(runtime_env) -/// // include support for built in functions and configurations +/// // include support for built-in functions and configurations /// .with_default_features() /// .build(); /// @@ -308,7 +321,7 @@ impl SessionContext { let schema = cat .schema(schema_name.as_str()) .ok_or_else(|| internal_datafusion_err!("Schema not found!"))?; - let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>(); + let lister = schema.downcast_ref::<ListingSchemaProvider>(); if let Some(lister) = lister { lister.refresh(&self.state()).await?; } @@ -476,6 +489,11 @@ impl SessionContext { self.state.write().append_optimizer_rule(optimizer_rule); } + /// Removes an optimizer rule by name, returning `true` if it existed. + pub fn remove_optimizer_rule(&self, name: &str) -> bool { + self.state.write().remove_optimizer_rule(name) + } + /// Adds an analyzer rule to the end of the existing rules. /// /// See [`SessionState`] for more control of when the rule is applied. @@ -513,19 +531,14 @@ impl SessionContext { self.runtime_env().deregister_object_store(url) } - /// Registers the [`RecordBatch`] as the specified table name + /// Registers the given [`RecordBatch`] as the specified table reference. pub fn register_batch( &self, - table_name: &str, + table_ref: impl Into<TableReference>, batch: RecordBatch, ) -> Result<Option<Arc<dyn TableProvider>>> { let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - self.register_table( - TableReference::Bare { - table: table_name.into(), - }, - Arc::new(table), - ) + self.register_table(table_ref, Arc::new(table)) } /// Return the [RuntimeEnv] used to run queries with this `SessionContext` @@ -678,7 +691,7 @@ impl SessionContext { match ddl { DdlStatement::CreateExternalTable(cmd) => { (Box::pin(async move { self.create_external_table(&cmd).await }) - as std::pin::Pin + Send>>) + as std::pin::Pin + Send>>) .await } DdlStatement::CreateMemoryTable(cmd) => { @@ -709,7 +722,12 @@ impl SessionContext { } // TODO what about the other statements (like TransactionStart and TransactionEnd) LogicalPlan::Statement(Statement::SetVariable(stmt)) => { - self.set_variable(stmt).await + self.set_variable(stmt).await?; + self.return_empty_dataframe() + } + LogicalPlan::Statement(Statement::ResetVariable(stmt)) => { + self.reset_variable(stmt).await?; + self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Prepare(Prepare { name, @@ -727,12 +745,19 @@ impl SessionContext { ); } } - // Store the unoptimized plan into the session state. Although storing the - // optimized plan or the physical plan would be more efficient, doing so is - // not currently feasible. This is because `now()` would be optimized to a - // constant value, causing each EXECUTE to yield the same result, which is - // incorrect behavior.
- self.state.write().store_prepared(name, fields, input)?; + // Optimize the plan without evaluating expressions like now() + let optimizer_context = OptimizerContext::new_with_config_options( + Arc::clone(self.state().config().options()), + ) + .without_query_execution_start_time(); + let plan = self.state().optimizer().optimize( + Arc::unwrap_or_clone(input), + &optimizer_context, + |_1, _2| {}, + )?; + self.state + .write() + .store_prepared(name, fields, Arc::new(plan))?; self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Execute(execute)) => { @@ -774,7 +799,7 @@ impl SessionContext { /// * [`SessionState::create_physical_expr`] for a lower level API /// /// [simplified]: datafusion_optimizer::simplify_expressions - /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs pub fn create_physical_expr( &self, expr: Expr, @@ -926,13 +951,13 @@ impl SessionContext { match (or_replace, view) { (true, Ok(_)) => { self.deregister_table(name.clone())?; - let input = Self::apply_type_coercion(input.as_ref().clone())?; + let input = Self::apply_type_coercion(Arc::unwrap_or_clone(input))?; let table = Arc::new(ViewTable::new(input, definition)); self.register_table(name, table)?; self.return_empty_dataframe() } (_, Err(_)) => { - let input = Self::apply_type_coercion(input.as_ref().clone())?; + let input = Self::apply_type_coercion(Arc::unwrap_or_clone(input))?; let table = Arc::new(ViewTable::new(input, definition)); self.register_table(name, table)?; self.return_empty_dataframe() @@ -1052,22 +1077,22 @@ impl SessionContext { } else if allow_missing { return self.return_empty_dataframe(); } else { - return self.schema_doesnt_exist_err(name); + return self.schema_doesnt_exist_err(&name); } }; let dereg = catalog.deregister_schema(name.schema_name(), cascade)?; match (dereg, allow_missing) { (None, true) => self.return_empty_dataframe(), - (None, false) => self.schema_doesnt_exist_err(name), + (None, false) => self.schema_doesnt_exist_err(&name), (Some(_), _) => self.return_empty_dataframe(), } } - fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result<DataFrame> { - exec_err!("Schema '{schemaref}' doesn't exist.") + fn schema_doesnt_exist_err(&self, schema_ref: &SchemaReference) -> Result<DataFrame> { + exec_err!("Schema '{schema_ref}' doesn't exist.") } - async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> { + async fn set_variable(&self, stmt: SetVariable) -> Result<()> { let SetVariable { variable, value, ..
} = stmt; @@ -1097,11 +1122,37 @@ impl SessionContext { for udf in udfs_to_update { state.register_udf(udf)?; } + } - drop(state); + Ok(()) + } + + async fn reset_variable(&self, stmt: ResetVariable) -> Result<()> { + let variable = stmt.variable; + if variable.starts_with("datafusion.runtime.") { + return self.reset_runtime_variable(&variable); } - self.return_empty_dataframe() + let mut state = self.state.write(); + state.config_mut().options_mut().reset(&variable)?; + + // Refresh UDFs to ensure configuration-dependent behavior updates + let config_options = state.config().options(); + let udfs_to_update: Vec<_> = state + .scalar_functions() + .values() + .filter_map(|udf| { + udf.inner() + .with_updated_config(config_options) + .map(Arc::new) + }) + .collect(); + + for udf in udfs_to_update { + state.register_udf(udf)?; + } + + Ok(()) } fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> { @@ -1112,18 +1163,65 @@ impl SessionContext { let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); builder = match key { "memory_limit" => { - let memory_limit = Self::parse_memory_limit(value)?; + let memory_limit = Self::parse_capacity_limit(variable, value)?; builder.with_memory_limit(memory_limit, 1.0) } "max_temp_directory_size" => { - let directory_size = Self::parse_memory_limit(value)?; + let directory_size = Self::parse_capacity_limit(variable, value)?; builder.with_max_temp_directory_size(directory_size as u64) } "temp_directory" => builder.with_temp_file_path(value), "metadata_cache_limit" => { - let limit = Self::parse_memory_limit(value)?; + let limit = Self::parse_capacity_limit(variable, value)?; builder.with_metadata_cache_limit(limit) } + "list_files_cache_limit" => { + let limit = Self::parse_capacity_limit(variable, value)?; + builder.with_object_list_cache_limit(limit) + } + "list_files_cache_ttl" => { + let duration = Self::parse_duration(variable, value)?; + builder.with_object_list_cache_ttl(Some(duration)) + } + _ => return plan_err!("Unknown runtime configuration: {variable}"), + // Remember to update `reset_runtime_variable()` when adding new options + }; + + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + + Ok(()) + } + + fn reset_runtime_variable(&self, variable: &str) -> Result<()> { + let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + + let mut state = self.state.write(); + + let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); + match key { + "memory_limit" => { + builder.memory_pool = None; + } + "max_temp_directory_size" => { + builder = + builder.with_max_temp_directory_size(DEFAULT_MAX_TEMP_DIRECTORY_SIZE); + } + "temp_directory" => { + builder.disk_manager_builder = Some(DiskManagerBuilder::default()); + } + "metadata_cache_limit" => { + builder = builder.with_metadata_cache_limit(DEFAULT_METADATA_CACHE_LIMIT); + } + "list_files_cache_limit" => { + builder = builder + .with_object_list_cache_limit(DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT); + } + "list_files_cache_ttl" => { + builder = + builder.with_object_list_cache_ttl(DEFAULT_LIST_FILES_CACHE_TTL); + } _ => return plan_err!("Unknown runtime configuration: {variable}"), }; @@ -1150,11 +1248,23 @@ impl SessionContext { /// (1.5 * 1024.0 * 1024.0 * 1024.0) as usize /// ); /// ``` + #[deprecated( + since = "53.0.0", + note = "please use `parse_capacity_limit` function instead." 
+ )] pub fn parse_memory_limit(limit: &str) -> Result<usize> { + if limit.trim().is_empty() { + return Err(plan_datafusion_err!("Empty limit value found!")); + } let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number.parse().map_err(|_| { plan_datafusion_err!("Failed to parse number from memory limit '{limit}'") })?; + if number.is_sign_negative() || number.is_infinite() { + return Err(plan_datafusion_err!( + "Limit value should be positive finite number" + )); + } match unit { "K" => Ok((number * 1024.0) as usize), @@ -1164,6 +1274,111 @@ impl SessionContext { } } + /// Parse capacity limit from string to number of bytes by allowing units: K, M and G. + /// Supports formats like '1.5G', '100M', '512K' + /// + /// # Examples + /// ``` + /// use datafusion::execution::context::SessionContext; + /// + /// assert_eq!( + /// SessionContext::parse_capacity_limit("datafusion.runtime.memory_limit", "1M").unwrap(), + /// 1024 * 1024 + /// ); + /// assert_eq!( + /// SessionContext::parse_capacity_limit("datafusion.runtime.memory_limit", "1.5G").unwrap(), + /// (1.5 * 1024.0 * 1024.0 * 1024.0) as usize + /// ); + /// ``` + pub fn parse_capacity_limit(config_name: &str, limit: &str) -> Result<usize> { + if limit.trim().is_empty() { + return Err(plan_datafusion_err!( + "Empty limit value found for '{config_name}'" + )); + } + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + plan_datafusion_err!( + "Failed to parse number from '{config_name}', limit '{limit}'" + ) + })?; + if number.is_sign_negative() || number.is_infinite() { + return Err(plan_datafusion_err!( + "Limit value should be positive finite number for '{config_name}'" + )); + } + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => plan_err!( + "Unsupported unit '{unit}' in '{config_name}', limit '{limit}'. \ Unit must be one of: 'K', 'M', 'G'" + ), + } + } + + fn parse_duration(config_name: &str, duration: &str) -> Result<Duration> { + if duration.trim().is_empty() { + return Err(plan_datafusion_err!( + "Duration should not be empty or blank for '{config_name}'" + )); + } + + let mut minutes = None; + let mut seconds = None; + + for duration in duration.split_inclusive(&['m', 's']) { + let (number, unit) = duration.split_at(duration.len() - 1); + let number: u64 = number.parse().map_err(|_| { + plan_datafusion_err!("Failed to parse number from duration '{duration}' for '{config_name}'") + })?; + + match unit { + "m" if minutes.is_none() && seconds.is_none() => minutes = Some(number), + "s" if seconds.is_none() => seconds = Some(number), + other => plan_err!( + "Invalid duration unit: '{other}'. The unit must be either 'm' (minutes), or 's' (seconds), and be in the correct order for '{config_name}'" + )?, + } + } + + let secs = Self::check_overflow(config_name, minutes, 60, seconds)?; + let duration = Duration::from_secs(secs); + + if duration.is_zero() { + return plan_err!( + "Duration must be greater than 0 seconds for '{config_name}'" + ); + } + + Ok(duration) + } + + fn check_overflow( + config_name: &str, + mins: Option<u64>, + multiplier: u64, + secs: Option<u64>, + ) -> Result<u64> { + let first_part_of_secs = mins.unwrap_or_default().checked_mul(multiplier); + if first_part_of_secs.is_none() { + plan_err!( + "Duration has overflowed allowed maximum limit due to 'mins * {multiplier}' when setting '{config_name}'" + )?
+ } + let second_part_of_secs = first_part_of_secs + .unwrap() + .checked_add(secs.unwrap_or_default()); + if second_part_of_secs.is_none() { + plan_err!( + "Duration has overflowed allowed maximum limit due to 'mins * {multiplier} + secs' when setting '{config_name}'" + )? + } + Ok(second_part_of_secs.unwrap()) + } + async fn create_custom_table( &self, cmd: &CreateExternalTable, @@ -1190,20 +1405,24 @@ impl SessionContext { let table = table_ref.table().to_owned(); let maybe_schema = { let state = self.state.read(); - let resolved = state.resolve_table_ref(table_ref); + let resolved = state.resolve_table_ref(table_ref.clone()); state .catalog_list() .catalog(&resolved.catalog) .and_then(|c| c.schema(&resolved.schema)) }; - if let Some(schema) = maybe_schema { - if let Some(table_provider) = schema.table(&table).await? { - if table_provider.table_type() == table_type { - schema.deregister_table(&table)?; - return Ok(true); - } + if let Some(schema) = maybe_schema + && let Some(table_provider) = schema.table(&table).await? + && table_provider.table_type() == table_type + { + schema.deregister_table(&table)?; + if table_type == TableType::Base + && let Some(lfc) = self.runtime_env().cache_manager.get_list_files_cache() + { + lfc.drop_table_entries(&Some(table_ref))?; } + return Ok(true); } Ok(false) @@ -1219,7 +1438,7 @@ impl SessionContext { _ => { return Err(DataFusionError::Configuration( "Function factory has not been configured".to_string(), - )) + )); } } }; @@ -1269,14 +1488,24 @@ impl SessionContext { exec_datafusion_err!("Prepared statement '{}' does not exist", name) })?; + let state = self.state.read(); + let context = SimplifyContext::builder() + .with_schema(Arc::clone(prepared.plan.schema())) + .with_config_options(Arc::clone(state.config_options())) + .with_query_execution_start_time( + state.execution_props().query_execution_start_time, + ) + .build(); + let simplifier = ExprSimplifier::new(context); + // Only allow literals as parameters for now. let mut params: Vec<ScalarAndMetadata> = parameters .into_iter() - .map(|e| match e { + .map(|e| match simplifier.simplify(e)? { Expr::Literal(scalar, metadata) => { Ok(ScalarAndMetadata::new(scalar, metadata)) } - _ => not_impl_err!("Unsupported parameter type: {}", e), + e => not_impl_err!("Unsupported parameter type: {e}"), }) .collect::<Result<Vec<_>>>()?; @@ -1359,6 +1588,18 @@ impl SessionContext { self.state.write().register_udwf(Arc::new(f)).ok(); } + #[cfg(feature = "sql")] + /// Registers a [`RelationPlanner`] to customize SQL table-factor planning. + /// + /// Planners are invoked in reverse registration order, allowing newer + /// planners to take precedence over existing ones. + pub fn register_relation_planner( + &self, + planner: Arc<dyn RelationPlanner>, + ) -> Result<()> { + self.state.write().register_relation_planner(planner) + } + /// Deregisters a UDF within this context. pub fn deregister_udf(&self, name: &str) { self.state.write().deregister_udf(name).ok(); } @@ -1544,15 +1785,14 @@ impl SessionContext { /// SQL statements executed against this context.
pub async fn register_arrow( &self, - name: &str, - table_path: &str, + table_ref: impl Into<TableReference>, + table_path: impl AsRef<str>, options: ArrowReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); - self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), @@ -1738,6 +1978,10 @@ impl FunctionRegistry for SessionContext { self.state.read().udf(name) } + fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>> { + self.state.read().higher_order_function(name) + } + fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> { self.state.read().udaf(name) } @@ -1750,6 +1994,13 @@ self.state.write().register_udf(udf) } + fn register_higher_order_function( + &mut self, + function: Arc<HigherOrderUDF>, + ) -> Result<Option<Arc<HigherOrderUDF>>> { + self.state.write().register_higher_order_function(function) + } + fn register_udaf( &mut self, udaf: Arc<AggregateUDF>, @@ -1779,6 +2030,10 @@ self.state.write().register_expr_planner(expr_planner) } + fn higher_order_function_names(&self) -> HashSet<String> { + self.state.read().higher_order_function_names() + } + fn udafs(&self) -> HashSet<String> { self.state.read().udafs() } @@ -1788,6 +2043,12 @@ } } +impl datafusion_execution::TaskContextProvider for SessionContext { + fn task_ctx(&self) -> Arc<TaskContext> { + SessionContext::task_ctx(self) + } +} /// Create a new task context instance from SessionContext impl From<&SessionContext> for TaskContext { fn from(session: &SessionContext) -> Self { @@ -1831,7 +2092,7 @@ pub trait QueryPlanner: Debug { /// because the implementation and requirements vary widely. Please see /// [function_factory example] for a reference implementation.
/// -/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/function_factory.rs +/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/function_factory.rs /// /// # Examples of syntax that can be supported /// @@ -1998,7 +2259,9 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use arrow::datatypes::{DataType, TimeUnit}; + use arrow_schema::FieldRef; use datafusion_common::DataFusionError; + use datafusion_common::datatype::DataTypeExt; use std::error::Error; use std::path::PathBuf; @@ -2023,7 +2286,7 @@ mod tests { // configure with same memory / disk manager let memory_pool = ctx1.runtime_env().memory_pool.clone(); - let mut reservation = MemoryConsumer::new("test").register(&memory_pool); + let reservation = MemoryConsumer::new("test").register(&memory_pool); reservation.grow(100); let disk_manager = ctx1.runtime_env().disk_manager.clone(); @@ -2515,7 +2778,7 @@ struct MyTypePlanner {} impl TypePlanner for MyTypePlanner { - fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> { + fn plan_type_field(&self, sql_type: &ast::DataType) -> Result<Option<FieldRef>> { match sql_type { ast::DataType::Datetime(precision) => { let precision = match precision { @@ -2525,10 +2788,213 @@ mod tests { None | Some(9) => TimeUnit::Nanosecond, _ => unreachable!(), }; - Ok(Some(DataType::Timestamp(precision, None))) + Ok(Some( + DataType::Timestamp(precision, None).into_nullable_field_ref(), + )) } _ => Ok(None), } } } + + #[tokio::test] + async fn remove_optimizer_rule() -> Result<()> { + let get_optimizer_rules = |ctx: &SessionContext| { + ctx.state() + .optimizer() + .rules + .iter() + .map(|r| r.name().to_owned()) + .collect::<HashSet<_>>() + }; + + let ctx = SessionContext::new(); + assert!(get_optimizer_rules(&ctx).contains("simplify_expressions")); + + // default plan + let plan = ctx + .sql("select 1 + 1") + .await? + .into_optimized_plan()? + .to_string(); + assert_snapshot!(plan, @r" + Projection: Int64(2) AS Int64(1) + Int64(1) + EmptyRelation: rows=1 + "); + + assert!(ctx.remove_optimizer_rule("simplify_expressions")); + assert!(!get_optimizer_rules(&ctx).contains("simplify_expressions")); + + // plan without the simplify_expressions rule + let plan = ctx + .sql("select 1 + 1") + .await? + .into_optimized_plan()?
+ .to_string(); + assert_snapshot!(plan, @r" + Projection: Int64(1) + Int64(1) + EmptyRelation: rows=1 + "); + + // attempting to remove a non-existing rule returns false + assert!(!ctx.remove_optimizer_rule("simplify_expressions")); + + Ok(()) + } + + #[test] + fn test_parse_duration() { + const LIST_FILES_CACHE_TTL: &str = "datafusion.runtime.list_files_cache_ttl"; + + // Valid durations + for (duration, want) in [ + ("1s", Duration::from_secs(1)), + ("1m", Duration::from_secs(60)), + ("1m0s", Duration::from_secs(60)), + ("1m1s", Duration::from_secs(61)), + ] { + let have = + SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration).unwrap(); + assert_eq!(want, have); + } + + // Invalid durations + for duration in [ + "0s", "0m", "1s0m", "1s1m", "XYZ", "1h", "XYZm2s", "", " ", "-1m", "1m 1s", + "1m1s ", " 1m1s", + ] { + let have = SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration); + assert!(have.is_err()); + assert!( + have.unwrap_err() + .message() + .to_string() + .contains(LIST_FILES_CACHE_TTL) + ); + } + } + + #[test] + fn test_parse_duration_with_overflow_check() { + const LIST_FILES_CACHE_TTL: &str = "datafusion.runtime.list_files_cache_ttl"; + + // Valid durations which are close to max allowed limit + for (duration, want) in [ + ( + "18446744073709551615s", + Duration::from_secs(18446744073709551615), + ), + ( + "307445734561825860m", + Duration::from_secs(307445734561825860 * 60), + ), + ( + "307445734561825860m10s", + Duration::from_secs(307445734561825860 * 60 + 10), + ), + ( + "1m18446744073709551555s", + Duration::from_secs(60 + 18446744073709551555), + ), + ] { + let have = + SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration).unwrap(); + assert_eq!(want, have); + } + + // Invalid durations which overflow max allowed limit + for (duration, error_message_prefix) in [ + ( + "18446744073709551616s", + "Failed to parse number from duration", + ), + ( + "307445734561825861m", + "Duration has overflowed allowed maximum limit due to", + ), + ( + "307445734561825860m60s", + "Duration has overflowed allowed maximum limit due to", + ), + ( + "1m18446744073709551556s", + "Duration has overflowed allowed maximum limit due to", + ), + ] { + let have = SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration); + assert!(have.is_err()); + let error_message = have.unwrap_err().message().to_string(); + assert!( + error_message.contains(error_message_prefix) + && error_message.contains(LIST_FILES_CACHE_TTL) + ); + } + } + + #[test] + fn test_parse_memory_limit() { + // Valid memory_limit + for (limit, want) in [ + ("1.5K", (1.5 * 1024.0) as usize), + ("2M", (2f64 * 1024.0 * 1024.0) as usize), + ("1G", (1f64 * 1024.0 * 1024.0 * 1024.0) as usize), + ] { + #[expect(deprecated)] + let have = SessionContext::parse_memory_limit(limit).unwrap(); + assert_eq!(want, have); + } + + // Invalid memory_limit + for limit in [ + "1B", + "1T", + "", + " ", + "XYZG", + "-1G", + "infG", + "-infG", + "G", + "1024B", + "invalid_size", + ] { + #[expect(deprecated)] + let have = SessionContext::parse_memory_limit(limit); + assert!(have.is_err()); + } + } + + #[test] + fn test_parse_capacity_limit() { + const MEMORY_LIMIT: &str = "datafusion.runtime.memory_limit"; + + // Valid capacity_limit + for (limit, want) in [ + ("1.5K", (1.5 * 1024.0) as usize), + ("2M", (2f64 * 1024.0 * 1024.0) as usize), + ("1G", (1f64 * 1024.0 * 1024.0 * 1024.0) as usize), + ] { + let have = SessionContext::parse_capacity_limit(MEMORY_LIMIT, limit).unwrap(); + assert_eq!(want, have); + } + + // Invalid 
capacity_limit + for limit in [ + "1B", + "1T", + "", + " ", + "XYZG", + "-1G", + "infG", + "-infG", + "G", + "1024B", + "invalid_size", + ] { + let have = SessionContext::parse_capacity_limit(MEMORY_LIMIT, limit); + assert!(have.is_err()); + assert!(have.unwrap_err().to_string().contains(MEMORY_LIMIT)); + } + } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 731f7e59ecfaf..823dc946ea732 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -113,7 +113,7 @@ mod tests { }; use datafusion_execution::config::SessionConfig; - use tempfile::{tempdir, TempDir}; + use tempfile::{TempDir, tempdir}; #[tokio::test] async fn read_with_glob_path() -> Result<()> { @@ -355,7 +355,9 @@ mod tests { let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{expected_path}' does not match the expected extension '.parquet'") + format!( + "Execution error: File path '{expected_path}' does not match the expected extension '.parquet'" + ) ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..de5e6b97c1af9 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -27,14 +27,14 @@ use crate::catalog::{CatalogProviderList, SchemaProvider, TableProviderFactory}; use crate::datasource::file_format::FileFormatFactory; #[cfg(feature = "sql")] use crate::datasource::provider_as_source; -use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::execution::SessionStateDefaults; +use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use arrow_schema::{DataType, FieldRef}; +use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::information_schema::{ - InformationSchemaProvider, INFORMATION_SCHEMA, + INFORMATION_SCHEMA, InformationSchemaProvider, }; -use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::{TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; #[cfg(feature = "sql")] @@ -43,23 +43,26 @@ use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, - ResolvedTableReference, TableReference, + DFSchema, DataFusionError, ResolvedTableReference, TableReference, config_err, + exec_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; +#[cfg(feature = "sql")] +use datafusion_expr::TableSource; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::planner::ExprPlanner; #[cfg(feature = "sql")] -use datafusion_expr::planner::TypePlanner; -use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; -use datafusion_expr::simplify::SimplifyInfo; -#[cfg(feature = "sql")] -use datafusion_expr::TableSource; 
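// Migration sketch: the removed `SimplifyInfo`-based plumbing (including the
// private `SessionSimplifyProvider` deleted at the end of this file) is
// replaced by the builder-based `SimplifyContext` imported below. The pattern,
// mirroring `SessionState::create_physical_expr` later in this diff
// (`query_start` stands in for `execution_props().query_execution_start_time`):
//
//     let ctx = SimplifyContext::builder()
//         .with_schema(Arc::new(df_schema.clone()))
//         .with_config_options(Arc::clone(config_options))
//         .with_query_execution_start_time(query_start)
//         .build();
//     let simplifier = ExprSimplifier::new(ctx);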
+use datafusion_expr::planner::{RelationPlanner, TypePlanner}; +use datafusion_expr::registry::{ + ExtensionTypeRegistryRef, FunctionRegistry, MemoryExtensionTypeRegistry, + SerializerRegistry, +}; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Explain, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -67,9 +70,11 @@ use datafusion_optimizer::{ }; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; +use datafusion_physical_optimizer::PhysicalOptimizerContext; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::operator_statistics::StatisticsRegistry; use datafusion_session::Session; #[cfg(feature = "sql")] use datafusion_sql::{ @@ -139,6 +144,8 @@ pub struct SessionState { analyzer: Analyzer, /// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, + #[cfg(feature = "sql")] + relation_planners: Vec>, /// Provides support for customizing the SQL type planning #[cfg(feature = "sql")] type_planner: Option>, @@ -154,10 +161,14 @@ pub struct SessionState { table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, + /// Higher order functions that are registered with the context + higher_order_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context window_functions: HashMap>, + /// Extension types registry for extensions. + extension_types: ExtensionTypeRegistryRef, /// Deserializer registry for extensions. serializer_registry: Arc, /// Holds registered external FileFormat implementations @@ -185,11 +196,28 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. /// thus, changing dialect o PostgreSql is required function_factory: Option>, + cache_factory: Option>, + /// Optional statistics registry for pluggable statistics providers. + /// + /// When set, physical optimizer rules can use this registry to obtain + /// enhanced statistics (e.g., NDV overrides, histograms) beyond what + /// is available from `ExecutionPlan::partition_statistics()`. + statistics_registry: Option, /// Cache logical plans of prepared statements for later execution. /// Key is the prepared statement name. 
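// Usage sketch for the new `statistics_registry` field, with a hypothetical
// `my_registry: StatisticsRegistry`: it is set via the builder method added
// further down in this diff and surfaced to physical optimizer rules through
// the `PhysicalOptimizerContext` impl below:
//
//     let state = SessionStateBuilder::new()
//         .with_default_features()
//         .with_statistics_registry(my_registry)
//         .build();
//     assert!(state.statistics_registry().is_some());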
prepared_plans: HashMap>, } +impl PhysicalOptimizerContext for SessionState { + fn config_options(&self) -> &ConfigOptions { + self.config_options() + } + + fn statistics_registry(&self) -> Option<&StatisticsRegistry> { + self.statistics_registry.as_ref() + } +} + impl Debug for SessionState { /// Prefer having short fields at the top and long vector fields near the end /// Group fields by @@ -206,8 +234,12 @@ impl Debug for SessionState { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) + .field("cache_factory", &self.cache_factory) .field("expr_planners", &self.expr_planners); + #[cfg(feature = "sql")] + let ret = ret.field("relation_planners", &self.relation_planners); + #[cfg(feature = "sql")] let ret = ret.field("type_planner", &self.type_planner); @@ -217,6 +249,7 @@ impl Debug for SessionState { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("higher_order_functions", &self.higher_order_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .field("prepared_plans", &self.prepared_plans) @@ -253,6 +286,10 @@ impl Session for SessionState { &self.scalar_functions } + fn higher_order_functions(&self) -> &HashMap> { + &self.higher_order_functions + } + fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions } @@ -261,6 +298,10 @@ impl Session for SessionState { &self.window_functions } + fn extension_type_registry(&self) -> &ExtensionTypeRegistryRef { + &self.extension_types + } + fn runtime_env(&self) -> &Arc { self.runtime_env() } @@ -345,6 +386,13 @@ impl SessionState { self.optimizer.rules.push(optimizer_rule); } + /// Removes an optimizer rule by name, returning `true` if it existed. + pub(crate) fn remove_optimizer_rule(&mut self, name: &str) -> bool { + let original_len = self.optimizer.rules.len(); + self.optimizer.rules.retain(|r| r.name() != name); + self.optimizer.rules.len() < original_len + } + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn set_function_factory(&mut self, function_factory: Arc) { self.function_factory = Some(function_factory); @@ -355,6 +403,16 @@ impl SessionState { self.function_factory.as_ref() } + /// Register a [`CacheFactory`] for custom caching strategy + pub fn set_cache_factory(&mut self, cache_factory: Arc) { + self.cache_factory = Some(cache_factory); + } + + /// Get the cache factory + pub fn cache_factory(&self) -> Option<&Arc> { + self.cache_factory.as_ref() + } + /// Get the table factories pub fn table_factories(&self) -> &HashMap> { &self.table_factories @@ -480,10 +538,10 @@ impl SessionState { let resolved = self.resolve_table_ref(reference); if let Entry::Vacant(v) = provider.tables.entry(resolved) { let resolved = v.key(); - if let Ok(schema) = self.schema_for_ref(resolved.clone()) { - if let Some(table) = schema.table(&resolved.table).await? { - v.insert(provider_as_source(table)); - } + if let Ok(schema) = self.schema_for_ref(resolved.clone()) + && let Some(table) = schema.table(&resolved.table).await? + { + v.insert(provider_as_source(table)); } } } @@ -547,6 +605,16 @@ impl SessionState { let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?; + self.create_logical_expr_from_sql_expr(sql_expr, df_schema) + } + + /// Creates a datafusion style AST [`Expr`] from a SQL expression. 
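+ ///
+ /// Sketch of the intended call pattern, following the
+ /// `test_create_logical_expr_from_sql_expr` test added at the bottom of
+ /// this file:
+ ///
+ /// ```rust,ignore
+ /// let sql_expr = state.sql_to_expr_with_alias("a > 10", &dialect)?;
+ /// let expr = state.create_logical_expr_from_sql_expr(sql_expr, &df_schema)?;
+ /// ```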
+ #[cfg(feature = "sql")] + pub fn create_logical_expr_from_sql_expr( + &self, + sql_expr: SQLExprWithAlias, + df_schema: &DFSchema, + ) -> datafusion_common::Result { let provider = SessionContextProvider { state: self, tables: HashMap::new(), @@ -571,6 +639,24 @@ impl SessionState { &self.expr_planners } + #[cfg(feature = "sql")] + /// Returns the registered relation planners in priority order. + pub fn relation_planners(&self) -> &[Arc] { + &self.relation_planners + } + + #[cfg(feature = "sql")] + /// Registers a [`RelationPlanner`] to customize SQL relation planning. + /// + /// Newly registered planners are given higher priority than existing ones. + pub fn register_relation_planner( + &mut self, + planner: Arc, + ) -> datafusion_common::Result<()> { + self.relation_planners.insert(0, planner); + Ok(()) + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -685,20 +771,26 @@ impl SessionState { /// * [`create_physical_expr`] for a lower-level API /// /// [simplified]: datafusion_optimizer::simplify_expressions - /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs /// [`SessionContext::create_physical_expr`]: crate::execution::context::SessionContext::create_physical_expr pub fn create_physical_expr( &self, expr: Expr, df_schema: &DFSchema, ) -> datafusion_common::Result> { - let simplifier = - ExprSimplifier::new(SessionSimplifyProvider::new(self, df_schema)); + let config_options = self.config_options(); + let simplify_context = SimplifyContext::builder() + .with_schema(Arc::new(df_schema.clone())) + .with_config_options(Arc::clone(config_options)) + .with_query_execution_start_time( + self.execution_props().query_execution_start_time, + ) + .build(); + let simplifier = ExprSimplifier::new(simplify_context); // apply type coercion here to ensure types match let mut expr = simplifier.coerce(expr, df_schema)?; // rewrite Exprs to functions if necessary - let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? @@ -752,6 +844,14 @@ impl SessionState { self.config.options() } + /// Returns the statistics registry if one is configured. + /// + /// The registry provides pluggable statistics providers for enhanced + /// cardinality estimation (e.g., NDV overrides, histograms). + pub fn statistics_registry(&self) -> Option<&StatisticsRegistry> { + self.statistics_registry.as_ref() + } + /// Mark the start of the execution pub fn mark_start_execution(&mut self) { let config = Arc::clone(self.config.options()); @@ -788,10 +888,18 @@ impl SessionState { overwrite: bool, ) -> Result<(), DataFusionError> { let ext = file_format.get_ext().to_lowercase(); - match (self.file_formats.entry(ext.clone()), overwrite){ - (Entry::Vacant(e), _) => {e.insert(file_format);}, - (Entry::Occupied(mut e), true) => {e.insert(file_format);}, - (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. 
Set overwrite to true to replace this extension."), + match (self.file_formats.entry(ext.clone()), overwrite) { + (Entry::Vacant(e), _) => { + e.insert(file_format); + } + (Entry::Occupied(mut e), true) => { + e.insert(file_format); + } + (Entry::Occupied(_), false) => { + return config_err!( + "File type already registered for extension {ext}. Set overwrite to true to replace this extension." + ); + } }; Ok(()) } @@ -815,11 +923,8 @@ impl SessionState { &self.catalog_list } - /// set the catalog list - pub(crate) fn register_catalog_list( - &mut self, - catalog_list: Arc, - ) { + /// Set the catalog list + pub fn register_catalog_list(&mut self, catalog_list: Arc) { self.catalog_list = catalog_list; } @@ -828,6 +933,11 @@ impl SessionState { &self.scalar_functions } + /// Return reference to higher_order_functions + pub fn higher_order_functions(&self) -> &HashMap> { + &self.higher_order_functions + } + /// Return reference to aggregate_functions pub fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions @@ -909,11 +1019,14 @@ impl SessionState { /// be used for all values unless explicitly provided. /// /// See example on [`SessionState`] +#[derive(Clone)] pub struct SessionStateBuilder { session_id: Option, analyzer: Option, expr_planners: Option>>, #[cfg(feature = "sql")] + relation_planners: Option>>, + #[cfg(feature = "sql")] type_planner: Option>, optimizer: Option, physical_optimizers: Option, @@ -921,8 +1034,10 @@ pub struct SessionStateBuilder { catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, + higher_order_functions: Option>>, aggregate_functions: Option>>, window_functions: Option>>, + extension_types: Option, serializer_registry: Option>, file_formats: Option>>, config: Option, @@ -931,6 +1046,8 @@ pub struct SessionStateBuilder { table_factories: Option>>, runtime_env: Option>, function_factory: Option>, + cache_factory: Option>, + statistics_registry: Option, // fields to support convenience functions analyzer_rules: Option>>, optimizer_rules: Option>>, @@ -951,6 +1068,8 @@ impl SessionStateBuilder { analyzer: None, expr_planners: None, #[cfg(feature = "sql")] + relation_planners: None, + #[cfg(feature = "sql")] type_planner: None, optimizer: None, physical_optimizers: None, @@ -958,8 +1077,10 @@ impl SessionStateBuilder { catalog_list: None, table_functions: None, scalar_functions: None, + higher_order_functions: None, aggregate_functions: None, window_functions: None, + extension_types: None, serializer_registry: None, file_formats: None, table_options: None, @@ -968,6 +1089,8 @@ impl SessionStateBuilder { table_factories: None, runtime_env: None, function_factory: None, + cache_factory: None, + statistics_registry: None, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1001,6 +1124,8 @@ impl SessionStateBuilder { analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), #[cfg(feature = "sql")] + relation_planners: Some(existing.relation_planners), + #[cfg(feature = "sql")] type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), @@ -1008,10 +1133,14 @@ impl SessionStateBuilder { catalog_list: Some(existing.catalog_list), table_functions: Some(existing.table_functions), scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + higher_order_functions: Some( + existing.higher_order_functions.into_values().collect_vec(), + ), aggregate_functions: Some( 
existing.aggregate_functions.into_values().collect_vec(), ), window_functions: Some(existing.window_functions.into_values().collect_vec()), + extension_types: Some(existing.extension_types), serializer_registry: Some(existing.serializer_registry), file_formats: Some(existing.file_formats.into_values().collect_vec()), config: Some(new_config), @@ -1020,7 +1149,8 @@ impl SessionStateBuilder { table_factories: Some(existing.table_factories), runtime_env: Some(existing.runtime_env), function_factory: existing.function_factory, - + cache_factory: existing.cache_factory, + statistics_registry: existing.statistics_registry, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1049,6 +1179,10 @@ impl SessionStateBuilder { .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); + self.higher_order_functions + .get_or_insert_with(Vec::new) + .extend(SessionStateDefaults::default_higher_order_functions()); + self.aggregate_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_aggregate_functions()); @@ -1057,6 +1191,11 @@ impl SessionStateBuilder { .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_window_functions()); + self.extension_types + .get_or_insert_with(|| Arc::new(MemoryExtensionTypeRegistry::new_empty())) + .extend(&SessionStateDefaults::default_extension_types()) + .expect("MemoryExtensionTypeRegistry is not read-only."); + self.table_functions .get_or_insert_with(HashMap::new) .extend( @@ -1141,6 +1280,16 @@ impl SessionStateBuilder { self } + #[cfg(feature = "sql")] + /// Sets the [`RelationPlanner`]s used to customize SQL relation planning. + pub fn with_relation_planners( + mut self, + relation_planners: Vec>, + ) -> Self { + self.relation_planners = Some(relation_planners); + self + } + /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. #[cfg(feature = "sql")] pub fn with_type_planner(mut self, type_planner: Arc) -> Self { @@ -1219,6 +1368,15 @@ impl SessionStateBuilder { self } + /// Set the map of [`HigherOrderUDF`]s + pub fn with_higher_order_functions( + mut self, + higher_order_functions: Vec>, + ) -> Self { + self.higher_order_functions = Some(higher_order_functions); + self + } + /// Set the map of [`AggregateUDF`]s pub fn with_aggregate_functions( mut self, @@ -1237,6 +1395,15 @@ impl SessionStateBuilder { self } + /// Sets the [`ExtensionTypeRegistry`](datafusion_expr::registry::ExtensionTypeRegistry). + pub fn with_extension_type_registry( + mut self, + registry: ExtensionTypeRegistryRef, + ) -> Self { + self.extension_types = Some(registry); + self + } + /// Set the [`SerializerRegistry`] pub fn with_serializer_registry( mut self, @@ -1309,6 +1476,25 @@ impl SessionStateBuilder { self } + /// Set a [`CacheFactory`] for custom caching strategy + pub fn with_cache_factory( + mut self, + cache_factory: Option>, + ) -> Self { + self.cache_factory = cache_factory; + self + } + + /// Set a [`StatisticsRegistry`] for pluggable statistics providers. + /// + /// The registry allows physical optimizer rules to access enhanced statistics + /// (e.g., NDV overrides, histograms) beyond what is available from + /// `ExecutionPlan::partition_statistics()`. + pub fn with_statistics_registry(mut self, registry: StatisticsRegistry) -> Self { + self.statistics_registry = Some(registry); + self + } + /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`] /// for more details. 
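// Combined builder sketch for the new extension points introduced above
// (`MyRelationPlanner` and `my_hofs: Vec<Arc<HigherOrderUDF>>` are
// hypothetical placeholders):
//
//     let state = SessionStateBuilder::new()
//         .with_default_features()
//         .with_relation_planners(vec![Arc::new(MyRelationPlanner)])
//         .with_higher_order_functions(my_hofs)
//         .build();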
/// @@ -1355,6 +1541,8 @@ impl SessionStateBuilder { analyzer, expr_planners, #[cfg(feature = "sql")] + relation_planners, + #[cfg(feature = "sql")] type_planner, optimizer, physical_optimizers, @@ -1362,8 +1550,10 @@ impl SessionStateBuilder { catalog_list, table_functions, scalar_functions, + higher_order_functions, aggregate_functions, window_functions, + extension_types, serializer_registry, file_formats, table_options, @@ -1372,6 +1562,8 @@ impl SessionStateBuilder { table_factories, runtime_env, function_factory, + cache_factory, + statistics_registry, analyzer_rules, optimizer_rules, physical_optimizer_rules, @@ -1385,6 +1577,8 @@ impl SessionStateBuilder { analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), #[cfg(feature = "sql")] + relation_planners: relation_planners.unwrap_or_default(), + #[cfg(feature = "sql")] type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), @@ -1395,8 +1589,10 @@ impl SessionStateBuilder { }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), + higher_order_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), + extension_types: Arc::new(MemoryExtensionTypeRegistry::default()), serializer_registry: serializer_registry .unwrap_or_else(|| Arc::new(EmptySerializerRegistry)), file_formats: HashMap::new(), @@ -1408,6 +1604,8 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + cache_factory, + statistics_registry, prepared_plans: HashMap::new(), }; @@ -1447,6 +1645,29 @@ impl SessionStateBuilder { } } + if let Some(higher_order_functions) = higher_order_functions { + for function in higher_order_functions { + match state.register_higher_order_function(Arc::clone(&function)) { + Ok(Some(existing)) => { + debug!( + "Overwrote existing higher-order function '{}'", + existing.name() + ); + } + Ok(None) => { + debug!("Registered higher-order function '{}'", function.name()); + } + Err(err) => { + debug!( + "Failed to register higher-order function '{}': {}", + function.name(), + err + ); + } + } + } + } + if let Some(aggregate_functions) = aggregate_functions { aggregate_functions.into_iter().for_each(|udaf| { let existing_udf = state.register_udaf(udaf); @@ -1465,6 +1686,10 @@ impl SessionStateBuilder { }); } + if let Some(extension_types) = extension_types { + state.extension_types = extension_types; + } + if state.config.create_default_catalog_and_schema() { let default_catalog = SessionStateDefaults::default_catalog( &state.config, @@ -1521,6 +1746,12 @@ impl SessionStateBuilder { &mut self.expr_planners } + #[cfg(feature = "sql")] + /// Returns a mutable reference to the current [`RelationPlanner`] list. 
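+ ///
+ /// Note on priority: `SessionState::register_relation_planner` (earlier in
+ /// this diff) inserts at the front of this list, so the most recently
+ /// registered planner is consulted first. A sketch with a hypothetical
+ /// `MyRelationPlanner`:
+ ///
+ /// ```rust,ignore
+ /// // the newest planner is consulted first
+ /// state.register_relation_planner(Arc::new(MyRelationPlanner))?;
+ /// ```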
+ pub fn relation_planners(&mut self) -> &mut Option>> { + &mut self.relation_planners + } + /// Returns the current type_planner value #[cfg(feature = "sql")] pub fn type_planner(&mut self) -> &mut Option> { @@ -1559,6 +1790,13 @@ impl SessionStateBuilder { &mut self.scalar_functions } + /// Returns the current scalar_functions value + pub fn higher_order_functions( + &mut self, + ) -> &mut Option>> { + &mut self.higher_order_functions + } + /// Returns the current aggregate_functions value pub fn aggregate_functions(&mut self) -> &mut Option>> { &mut self.aggregate_functions @@ -1611,6 +1849,11 @@ impl SessionStateBuilder { &mut self.function_factory } + /// Returns the cache factory + pub fn cache_factory(&mut self) -> &mut Option> { + &mut self.cache_factory + } + /// Returns the current analyzer_rules value pub fn analyzer_rules( &mut self, @@ -1649,6 +1892,7 @@ impl Debug for SessionStateBuilder { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) + .field("cache_factory", &self.cache_factory) .field("expr_planners", &self.expr_planners); #[cfg(feature = "sql")] let ret = ret.field("type_planner", &self.type_planner); @@ -1661,6 +1905,7 @@ impl Debug for SessionStateBuilder { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("higher_order_functions", &self.higher_order_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .finish() @@ -1695,6 +1940,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.expr_planners() } + fn get_relation_planners(&self) -> &[Arc] { + self.state.relation_planners() + } + fn get_type_planner(&self) -> Option> { if let Some(type_planner) = &self.state.type_planner { Some(Arc::clone(type_planner)) @@ -1719,20 +1968,32 @@ impl ContextProvider for SessionContextProvider<'_> { name: &str, args: Vec, ) -> datafusion_common::Result> { + use datafusion_catalog::TableFunctionArgs; + let tbl_func = self .state .table_functions .get(name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; - let dummy_schema = DFSchema::empty(); - let simplifier = - ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema)); + let simplify_context = SimplifyContext::builder() + .with_config_options(Arc::clone(self.state.config_options())) + .with_query_execution_start_time( + self.state.execution_props().query_execution_start_time, + ) + .build(); + let simplifier = ExprSimplifier::new(simplify_context); + let schema = DFSchema::empty(); let args = args .into_iter() - .map(|arg| simplifier.simplify(arg)) + .map(|arg| { + simplifier + .coerce(arg, &schema) + .and_then(|e| simplifier.simplify(e)) + }) .collect::>>()?; - let provider = tbl_func.create_table_provider(&args)?; + let provider = tbl_func + .create_table_provider_with_args(TableFunctionArgs::new(&args, self.state))?; Ok(provider_as_source(provider)) } @@ -1755,6 +2016,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } + fn get_higher_order_meta(&self, name: &str) -> Option> { + self.state.higher_order_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -1764,7 +2029,7 @@ impl ContextProvider for SessionContextProvider<'_> { } fn 
get_variable_type(&self, variable_names: &[String]) -> Option { - use datafusion_expr::var_provider::{is_system_variables, VarType}; + use datafusion_expr::var_provider::{VarType, is_system_variables}; if variable_names.is_empty() { return None; @@ -1791,6 +2056,14 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().keys().cloned().collect() } + fn higher_order_function_names(&self) -> Vec { + self.state + .higher_order_functions() + .keys() + .cloned() + .collect() + } + fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() } @@ -1830,6 +2103,16 @@ impl FunctionRegistry for SessionState { }) } + fn higher_order_function( + &self, + name: &str, + ) -> datafusion_common::Result> { + self.higher_order_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found")) + } + fn udaf(&self, name: &str) -> datafusion_common::Result> { let result = self.aggregate_functions.get(name); @@ -1857,6 +2140,19 @@ impl FunctionRegistry for SessionState { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn register_higher_order_function( + &mut self, + function: Arc, + ) -> datafusion_common::Result>> { + function.aliases().iter().for_each(|alias| { + self.higher_order_functions + .insert(alias.clone(), Arc::clone(&function)); + }); + Ok(self + .higher_order_functions + .insert(function.name().into(), function)) + } + fn register_udaf( &mut self, udaf: Arc, @@ -1892,6 +2188,19 @@ impl FunctionRegistry for SessionState { Ok(udf) } + fn deregister_higher_order_function( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let function = self.higher_order_functions.remove(name); + if let Some(function) = &function { + for alias in function.aliases() { + self.higher_order_functions.remove(alias); + } + } + Ok(function) + } + fn deregister_udaf( &mut self, name: &str, @@ -1938,6 +2247,10 @@ impl FunctionRegistry for SessionState { Ok(()) } + fn higher_order_function_names(&self) -> HashSet { + self.higher_order_functions.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.aggregate_functions.keys().cloned().collect() } @@ -1947,8 +2260,14 @@ impl FunctionRegistry for SessionState { } } +impl datafusion_execution::TaskContextProvider for SessionState { + fn task_ctx(&self) -> Arc { + SessionState::task_ctx(self) + } +} + impl OptimizerConfig for SessionState { - fn query_execution_start_time(&self) -> DateTime { + fn query_execution_start_time(&self) -> Option> { self.execution_props.query_execution_start_time } @@ -1974,6 +2293,7 @@ impl From<&SessionState> for TaskContext { state.session_id.clone(), state.config.clone(), state.scalar_functions.clone(), + state.higher_order_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), Arc::clone(&state.runtime_env), @@ -2000,35 +2320,6 @@ impl QueryPlanner for DefaultQueryPlanner { } } -struct SessionSimplifyProvider<'a> { - state: &'a SessionState, - df_schema: &'a DFSchema, -} - -impl<'a> SessionSimplifyProvider<'a> { - fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self { - Self { state, df_schema } - } -} - -impl SimplifyInfo for SessionSimplifyProvider<'_> { - fn is_boolean_type(&self, expr: &Expr) -> datafusion_common::Result { - Ok(expr.get_type(self.df_schema)? 
== DataType::Boolean) - } - - fn nullable(&self, expr: &Expr) -> datafusion_common::Result { - expr.nullable(self.df_schema) - } - - fn execution_props(&self) -> &ExecutionProps { - self.state.execution_props() - } - - fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { - expr.get_type(self.df_schema) - } -} - #[derive(Debug)] pub(crate) struct PreparedPlan { /// Data types of the parameters @@ -2037,14 +2328,27 @@ pub(crate) struct PreparedPlan { pub(crate) plan: Arc, } +/// A [`CacheFactory`] can be registered via [`SessionState`] +/// to create a custom logical plan for [`crate::dataframe::DataFrame::cache`]. +/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`] +/// may need to be implemented to handle such plans. +pub trait CacheFactory: Debug + Send + Sync { + /// Create a logical plan for caching + fn create( + &self, + plan: LogicalPlan, + session_state: &SessionState, + ) -> datafusion_common::Result; +} + #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; use crate::common::assert_contains; use crate::config::ConfigOptions; + use crate::datasource::MemTable; use crate::datasource::empty::EmptyTable; use crate::datasource::provider_as_source; - use crate::datasource::MemTable; use crate::execution::context::SessionState; use crate::logical_expr::planner::ExprPlanner; use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; @@ -2054,13 +2358,14 @@ mod tests { use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; - use datafusion_common::config::Dialect; use datafusion_common::DFSchema; use datafusion_common::Result; + use datafusion_common::config::Dialect; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; - use datafusion_optimizer::optimizer::OptimizerRule; + use datafusion_expr::HigherOrderUDF; use datafusion_optimizer::Optimizer; + use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; @@ -2097,6 +2402,36 @@ mod tests { assert!(sql_to_expr(&state).is_err()) } + #[test] + #[cfg(feature = "sql")] + fn test_create_logical_expr_from_sql_expr() { + let state = SessionStateBuilder::new().with_default_features().build(); + + let provider = SessionContextProvider { + state: &state, + tables: HashMap::new(), + }; + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + let dialect = state.config.options().sql_parser.dialect; + let query = SqlToRel::new_with_options(&provider, state.get_parser_options()); + + for sql in ["[1,2,3]", "a > 10", "SUM(a)"] { + let sql_expr = state.sql_to_expr(sql, &dialect).unwrap(); + let from_str = query + .sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) + .unwrap(); + + let sql_expr_with_alias = + state.sql_to_expr_with_alias(sql, &dialect).unwrap(); + let from_expr = state + .create_logical_expr_from_sql_expr(sql_expr_with_alias, &df_schema) + .unwrap(); + assert_eq!(from_str, from_expr); + } + } + #[test] fn test_from_existing() -> Result<()> { fn employee_batch() -> RecordBatch { @@ -2137,13 +2472,15 @@ mod tests { .table_exist("employee"); assert!(is_exist); let new_state = SessionStateBuilder::new_from_existing(session_state).build(); - assert!(new_state - .catalog_list() - 
.catalog(default_catalog.as_str()) - .unwrap() - .schema(default_schema.as_str()) - .unwrap() - .table_exist("employee")); + assert!( + new_state + .catalog_list() + .catalog(default_catalog.as_str()) + .unwrap() + .schema(default_schema.as_str()) + .unwrap() + .table_exist("employee") + ); // if `with_create_default_catalog_and_schema` is disabled, the new one shouldn't create default catalog and schema let disable_create_default = @@ -2151,10 +2488,12 @@ mod tests { let without_default_state = SessionStateBuilder::new() .with_config(disable_create_default) .build(); - assert!(without_default_state - .catalog_list() - .catalog(&default_catalog) - .is_none()); + assert!( + without_default_state + .catalog_list() + .catalog(&default_catalog) + .is_none() + ); let new_state = SessionStateBuilder::new_from_existing(without_default_state).build(); assert!(new_state.catalog_list().catalog(&default_catalog).is_none()); @@ -2338,6 +2677,10 @@ mod tests { self.state.scalar_functions().get(name).cloned() } + fn get_higher_order_meta(&self, name: &str) -> Option> { + self.state.higher_order_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -2358,6 +2701,14 @@ mod tests { self.state.scalar_functions().keys().cloned().collect() } + fn higher_order_function_names(&self) -> Vec { + self.state + .higher_order_functions() + .keys() + .cloned() + .collect() + } + fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 62a575541a5d8..5e85c1bbc5e9e 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -17,6 +17,7 @@ use crate::catalog::listing_schema::ListingSchemaProvider; use crate::catalog::{CatalogProvider, TableProviderFactory}; +use crate::datasource::file_format::FileFormatFactory; use crate::datasource::file_format::arrow::ArrowFormatFactory; #[cfg(feature = "avro")] use crate::datasource::file_format::avro::AvroFormatFactory; @@ -24,7 +25,6 @@ use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::json::JsonFormatFactory; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormatFactory; -use crate::datasource::file_format::FileFormatFactory; use crate::datasource::provider::DefaultTableFactory; use crate::execution::context::SessionState; #[cfg(feature = "nested_expressions")] @@ -36,7 +36,8 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::registry::ExtensionTypeRegistrationRef; +use datafusion_expr::{AggregateUDF, HigherOrderUDF, ScalarUDF, WindowUDF}; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -103,7 +104,7 @@ impl SessionStateDefaults { /// returns the list of default [`ScalarUDF`]s pub fn default_scalar_functions() -> Vec> { - #[cfg_attr(not(feature = "nested_expressions"), allow(unused_mut))] + #[cfg_attr(not(feature = "nested_expressions"), expect(unused_mut))] let mut functions: Vec> = functions::all_default_functions(); #[cfg(feature = "nested_expressions")] @@ -112,6 +113,15 @@ impl SessionStateDefaults { functions } + /// 
returns the list of default [`HigherOrderUDF`]s + pub fn default_higher_order_functions() -> Vec> { + #[cfg(feature = "nested_expressions")] + return functions_nested::all_default_higher_order_functions(); + + #[cfg(not(feature = "nested_expressions"))] + return Vec::new(); + } + /// returns the list of default [`AggregateUDF`]s pub fn default_aggregate_functions() -> Vec> { functions_aggregate::all_default_aggregate_functions() @@ -122,6 +132,13 @@ impl SessionStateDefaults { functions_window::all_default_window_functions() } + /// Returns the list of default extension types. + /// + /// For now, we do not register any extension types by default. + pub fn default_extension_types() -> Vec { + vec![] + } + /// returns the list of default [`TableFunction`]s pub fn default_table_functions() -> Vec> { functions_table::all_default_table_functions() @@ -155,7 +172,7 @@ impl SessionStateDefaults { } /// registers all the builtin array functions - #[cfg_attr(not(feature = "nested_expressions"), allow(unused_variables))] + #[cfg_attr(not(feature = "nested_expressions"), expect(unused_variables))] pub fn register_array_functions(state: &mut SessionState) { // register crate of array expressions (if enabled) #[cfg(feature = "nested_expressions")] diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 381dd5e9e8482..3170f4be7f683 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -35,6 +35,9 @@ ) )] #![warn(missing_docs, clippy::needless_borrow)] +// Use `allow` instead of `expect` for test configuration to explicitly +// disable the lint for all test code rather than expecting violations +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! [DataFusion] is an extensible query engine written in Rust that //! uses [Apache Arrow] as its in-memory format. DataFusion's target users are @@ -358,7 +361,7 @@ //! [`TreeNode`]: datafusion_common::tree_node::TreeNode //! [`tree_node module`]: datafusion_expr::logical_plan::tree_node //! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier -//! [`expr_api`.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! [`expr_api`.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs //! //! ### Physical Plans //! @@ -647,7 +650,7 @@ //! //! [Tokio]: https://tokio.rs //! [`Runtime`]: tokio::runtime::Runtime -//! [thread_pools example]: https://github.com/apache/datafusion/tree/main/datafusion-examples/examples/thread_pools.rs +//! [thread_pools example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/thread_pools.rs //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ //! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec @@ -758,14 +761,13 @@ //! [`RecordBatch`]: arrow::array::RecordBatch //! [`RecordBatchReader`]: arrow::record_batch::RecordBatchReader //! 
[`Array`]: arrow::array::Array - -/// DataFusion crate version -pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); +#![doc = include_str!("optimizer_rule_reference.md")] extern crate core; - #[cfg(feature = "sql")] extern crate sqlparser; +/// DataFusion crate version +pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); pub mod dataframe; pub mod datasource; @@ -783,7 +785,10 @@ pub use object_store; pub use parquet; #[cfg(feature = "avro")] -pub use datafusion_datasource_avro::apache_avro; +pub use datafusion_datasource_avro::arrow_avro; + +#[cfg(test)] +mod optimizer_rule_reference; // re-export DataFusion sub-crates at the top level. Use `pub use *` // so that the contents of the subcrates appears in rustdocs @@ -1177,8 +1182,56 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/upgrading.md", - library_user_guide_upgrading + "../../../docs/source/library-user-guide/upgrading/46.0.0.md", + library_user_guide_upgrading_46_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/47.0.0.md", + library_user_guide_upgrading_47_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.0.md", + library_user_guide_upgrading_48_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.1.md", + library_user_guide_upgrading_48_0_1 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/49.0.0.md", + library_user_guide_upgrading_49_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/50.0.0.md", + library_user_guide_upgrading_50_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/51.0.0.md", + library_user_guide_upgrading_51_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/52.0.0.md", + library_user_guide_upgrading_52_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/53.0.0.md", + library_user_guide_upgrading_53_0_0 ); #[cfg(doctest)] diff --git a/datafusion/core/src/optimizer_rule_reference.md b/datafusion/core/src/optimizer_rule_reference.md new file mode 100644 index 0000000000000..fcbb200c71624 --- /dev/null +++ b/datafusion/core/src/optimizer_rule_reference.md @@ -0,0 +1,93 @@ + + +## Built-in Optimizer Rules + +DataFusion applies a default analyzer, logical optimizer, and physical +optimizer pipeline. + +The rule names listed here match the names shown by `EXPLAIN VERBOSE`. + +Rule order matters. The default pipeline may change between releases. + +### Analyzer Rules + +| order | rule | summary | +| ----- | --------------------------- | --------------------------------------------------------------------------------------- | +| 1 | `resolve_grouping_function` | Rewrites `GROUPING(...)` calls into expressions over DataFusion's internal grouping id. | +| 2 | `type_coercion` | Adds implicit casts so operators and functions receive valid input types. 
| + +### Logical Optimizer Rules + +| order | rule | summary | +| ----- | ----------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | +| 1 | `rewrite_set_comparison` | Rewrites `ANY` and `ALL` set-comparison subqueries into `EXISTS`-based boolean expressions with correct SQL NULL semantics. | +| 2 | `optimize_unions` | Flattens nested unions and removes unions with a single input. | +| 3 | `simplify_expressions` | Constant-folds and simplifies expressions while preserving output names. | +| 4 | `replace_distinct_aggregate` | Rewrites `DISTINCT` and `DISTINCT ON` operators into aggregate-based plans that later rules can optimize further. | +| 5 | `eliminate_join` | Replaces keyless inner joins with a literal `false` filter by an empty relation. | +| 6 | `decorrelate_predicate_subquery` | Converts eligible `IN` and `EXISTS` predicate subqueries into semi or anti joins. | +| 7 | `scalar_subquery_to_join` | Rewrites eligible scalar subqueries into joins and adds schema-preserving projections. | +| 8 | `decorrelate_lateral_join` | Rewrites eligible lateral joins into regular joins. | +| 9 | `extract_equijoin_predicate` | Splits join filters into equijoin keys and residual predicates. | +| 10 | `eliminate_duplicated_expr` | Removes duplicate expressions from projections, aggregates, and similar operators. | +| 11 | `eliminate_filter` | Drops always-true filters and replaces always-false or NULL filters with empty relations. | +| 12 | `eliminate_cross_join` | Uses filter predicates to replace cross joins with inner joins when join keys can be found. | +| 13 | `eliminate_limit` | Removes no-op limits and simplifies trivial limit shapes. | +| 14 | `propagate_empty_relation` | Pushes empty-relation knowledge upward so operators fed by no rows collapse early. | +| 15 | `filter_null_join_keys` | Adds `IS NOT NULL` filters to nullable equijoin keys that can never match. | +| 16 | `eliminate_outer_join` | Rewrites outer joins to inner joins when later filters reject the NULL-extended rows. | +| 17 | `push_down_limit` | Moves literal limits closer to scans and unions and merges adjacent limits. | +| 18 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | +| 19 | `single_distinct_aggregation_to_group_by` | Rewrites single-column `DISTINCT` aggregations into two-stage `GROUP BY` plans. | +| 20 | `eliminate_group_by_constant` | Removes constant or functionally redundant expressions from `GROUP BY`. | +| 21 | `common_sub_expression_eliminate` | Computes repeated subexpressions once and reuses the result. | +| 22 | `extract_leaf_expressions` | Pulls cheap leaf expressions closer to data sources so later pruning and filter rules can act earlier. | +| 23 | `push_down_leaf_projections` | Pushes the helper projections created by leaf extraction toward leaf inputs. | +| 24 | `optimize_projections` | Prunes unused columns and removes unnecessary logical projections. | + +### Physical Optimizer Rules + +The same rule name may appear more than once when the default pipeline runs it +in multiple phases. + +| order | rule | phase | summary | +| ----- | ------------------------------ | ----------------------- | ------------------------------------------------------------------------------------------------------------ | +| 1 | `OutputRequirements` | add phase | Adds helper nodes so output requirements survive later physical rewrites. 
| +| 2 | `aggregate_statistics` | - | Uses exact source statistics to answer some aggregates without scanning data. | +| 3 | `join_selection` | - | Chooses join implementation, build side, and partition mode from statistics and stream properties. | +| 4 | `LimitedDistinctAggregation` | - | Pushes limit hints into grouped distinct-style aggregations when only a small result is needed. | +| 5 | `FilterPushdown` | pre-optimization phase | Pushes supported physical filters down toward data sources before distribution and sorting are enforced. | +| 6 | `EnforceDistribution` | - | Adds repartitioning only where needed to satisfy physical distribution requirements. | +| 7 | `CombinePartialFinalAggregate` | - | Collapses adjacent partial and final aggregates when the distributed shape makes them redundant. | +| 8 | `EnforceSorting` | - | Adds or removes local sorts to satisfy required input orderings. | +| 9 | `OptimizeAggregateOrder` | - | Updates aggregate expressions to use the best ordering once sort requirements are known. | +| 10 | `WindowTopN` | - | Replaces eligible row-number window and filter patterns with per-partition TopK execution. | +| 11 | `ProjectionPushdown` | early pass | Pushes projections toward inputs before later physical rewrites add more limit and TopK structure. | +| 12 | `OutputRequirements` | remove phase | Removes the temporary output-requirement helper nodes after requirement-sensitive planning is done. | +| 13 | `LimitAggregation` | - | Passes a limit hint into eligible aggregations so they can keep fewer accumulator buckets. | +| 14 | `LimitPushPastWindows` | - | Pushes fetch limits through bounded window operators when doing so keeps the result correct. | +| 15 | `HashJoinBuffering` | - | Adds buffering on the probe side of hash joins so probing can start before build completion. | +| 16 | `LimitPushdown` | - | Moves physical limits into child operators or fetch-enabled variants to cut data early. | +| 17 | `TopKRepartition` | - | Pushes TopK below hash repartition when the partition key is a prefix of the sort key. | +| 18 | `ProjectionPushdown` | late pass | Runs projection pushdown again after limit and TopK rewrites expose new pruning opportunities. | +| 19 | `PushdownSort` | - | Pushes sort requirements into data sources that can already return sorted output. | +| 20 | `EnsureCooperative` | - | Wraps non-cooperative plan parts so long-running tasks yield fairly. | +| 21 | `FilterPushdown(Post)` | post-optimization phase | Pushes dynamic filters at the end of optimization, after plan references stop moving. | +| 22 | `SanityCheckPlan` | - | Validates that the final physical plan meets ordering, distribution, and infinite-input safety requirements. | diff --git a/datafusion/core/src/optimizer_rule_reference.rs b/datafusion/core/src/optimizer_rule_reference.rs new file mode 100644 index 0000000000000..64db51b290fdc --- /dev/null +++ b/datafusion/core/src/optimizer_rule_reference.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_optimizer::analyzer::Analyzer; +use datafusion_optimizer::optimizer::Optimizer; +use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; + +const OPTIMIZER_RULE_REFERENCE: &str = include_str!("optimizer_rule_reference.md"); + +fn documented_rules(section_heading: &str) -> Vec { + let mut in_section = false; + let mut names = vec![]; + + for line in OPTIMIZER_RULE_REFERENCE.lines() { + if line == section_heading { + in_section = true; + continue; + } + + if in_section && line.starts_with("### ") { + break; + } + + if !in_section || !line.starts_with('|') || line.contains("---") { + continue; + } + + let columns: Vec<_> = line.split('|').map(str::trim).collect(); + + if columns.len() < 4 || columns[1] == "order" { + continue; + } + + names.push(columns[2].trim_matches('`').to_string()); + } + + names +} + +#[test] +fn analyzer_rules_match_documented_order() { + let rules: Vec<_> = Analyzer::new() + .rules + .iter() + .map(|rule| rule.name().to_string()) + .collect(); + + assert_eq!(documented_rules("### Analyzer Rules"), rules); +} + +#[test] +fn logical_rules_match_documented_order() { + let rules: Vec<_> = Optimizer::new() + .rules + .iter() + .map(|rule| rule.name().to_string()) + .collect(); + + assert_eq!(documented_rules("### Logical Optimizer Rules"), rules); +} + +#[test] +fn physical_rules_match_documented_order() { + let rules: Vec<_> = PhysicalOptimizer::new() + .rules + .iter() + .map(|rule| rule.name().to_string()) + .collect(); + + assert_eq!(documented_rules("### Physical Optimizer Rules"), rules); +} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c280b50a9f07a..3b2c7a78e898e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -18,13 +18,13 @@ //! 
Planner for [`LogicalPlan`] to [`ExecutionPlan`] use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; -use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::{source_as_provider, DefaultTableSource}; +use crate::datasource::physical_plan::{FileOutputMode, FileSinkConfig}; +use crate::datasource::{DefaultTableSource, source_as_provider}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; @@ -39,7 +39,7 @@ use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; use crate::physical_plan::explain::ExplainExec; -use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::filter::FilterExecBuilder; use crate::physical_plan::joins::utils as join_utils; use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, @@ -52,33 +52,42 @@ use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, - Partitioning, PhysicalExpr, WindowExpr, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, displayable, windows, }; use crate::schema_equivalence::schema_satisfied_by; -use arrow::array::{builder::StringBuilder, RecordBatch}; +use arrow::array::{RecordBatch, builder::StringBuilder}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; +use arrow_schema::Field; use datafusion_catalog::ScanArgs; +use datafusion_common::Column; +use datafusion_common::HashMap as DFHashMap; use datafusion_common::display::ToStringifiedPlan; -use datafusion_common::format::ExplainAnalyzeLevel; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::TableReference; +use datafusion_common::format::ExplainAnalyzeCategories; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; +use datafusion_common::{ + DFSchema, DFSchemaRef, ScalarValue, exec_err, internal_datafusion_err, internal_err, + not_impl_err, plan_err, +}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - ScalarValue, + TableReference, assert_eq_or_internal_err, assert_or_internal_err, }; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; use datafusion_expr::expr::{ - physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, - NullTreatment, WindowFunction, WindowFunctionParams, + AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, NullTreatment, + WindowFunction, WindowFunctionParams, physical_name, }; use datafusion_expr::expr_rewriter::unnormalize_cols; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use 
datafusion_expr::utils::split_conjunction; +use datafusion_expr::utils::{expr_to_columns, split_conjunction}; use datafusion_expr::{ Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan, @@ -87,21 +96,22 @@ use datafusion_expr::{ use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::{ - create_physical_sort_exprs, LexOrdering, PhysicalSortExpr, + LexOrdering, PhysicalSortExpr, create_physical_sort_exprs, }; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; -use datafusion_physical_plan::metrics::MetricType; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion_physical_plan::unnest::ListUnnest; use async_trait::async_trait; use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; -use itertools::{multiunzip, Itertools}; +use indexmap::IndexSet; +use itertools::{Itertools, multiunzip}; use log::debug; use tokio::sync::Mutex; @@ -151,6 +161,80 @@ pub trait ExtensionPlanner { physical_inputs: &[Arc], session_state: &SessionState, ) -> Result>>; + + /// Create a physical plan for a [`LogicalPlan::TableScan`]. + /// + /// This is useful for planning valid [`TableSource`]s that are not [`TableProvider`]s. + /// + /// Returns: + /// * `Ok(Some(plan))` if the planner knows how to plan the `scan` + /// * `Ok(None)` if the planner does not know how to plan the `scan` and wants to delegate the planning to another [`ExtensionPlanner`] + /// * `Err` if the planner knows how to plan the `scan` but errors while doing so + /// + /// # Example + /// + /// ```rust,ignore + /// use std::sync::Arc; + /// use datafusion::physical_plan::ExecutionPlan; + /// use datafusion::logical_expr::TableScan; + /// use datafusion::execution::context::SessionState; + /// use datafusion::error::Result; + /// use datafusion_physical_planner::{ExtensionPlanner, PhysicalPlanner}; + /// use async_trait::async_trait; + /// + /// // Your custom table source type + /// struct MyCustomTableSource { /* ... */ } + /// + /// // Your custom execution plan + /// struct MyCustomExec { /* ... 
*/ } + /// + /// struct MyExtensionPlanner; + /// + /// #[async_trait] + /// impl ExtensionPlanner for MyExtensionPlanner { + /// async fn plan_extension( + /// &self, + /// _planner: &dyn PhysicalPlanner, + /// _node: &dyn UserDefinedLogicalNode, + /// _logical_inputs: &[&LogicalPlan], + /// _physical_inputs: &[Arc], + /// _session_state: &SessionState, + /// ) -> Result>> { + /// Ok(None) + /// } + /// + /// async fn plan_table_scan( + /// &self, + /// _planner: &dyn PhysicalPlanner, + /// scan: &TableScan, + /// _session_state: &SessionState, + /// ) -> Result>> { + /// // Check if this is your custom table source + /// if scan.source.is::() { + /// // Create a custom execution plan for your table source + /// let exec = MyCustomExec::new( + /// scan.table_name.clone(), + /// Arc::clone(scan.projected_schema.inner()), + /// ); + /// Ok(Some(Arc::new(exec))) + /// } else { + /// // Return None to let other extension planners handle it + /// Ok(None) + /// } + /// } + /// } + /// ``` + /// + /// [`TableSource`]: datafusion_expr::TableSource + /// [`TableProvider`]: datafusion_catalog::TableProvider + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + _scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } } /// Default single node physical query planner that converts a @@ -272,7 +356,8 @@ struct LogicalNode<'a> { impl DefaultPhysicalPlanner { /// Create a physical planner that uses `extension_planners` to - /// plan user-defined logical nodes [`LogicalPlan::Extension`]. + /// plan user-defined logical nodes [`LogicalPlan::Extension`] + /// or user-defined table sources in [`LogicalPlan::TableScan`]. /// The planner uses the first [`ExtensionPlanner`] to return a non-`None` /// plan. pub fn with_extension_planners( @@ -281,8 +366,111 @@ impl DefaultPhysicalPlanner { Self { extension_planners } } - /// Create a physical plan from a logical plan - async fn create_initial_plan( + fn ensure_schema_matches( + &self, + logical_schema: &DFSchemaRef, + physical_plan: &Arc, + context: &str, + ) -> Result<()> { + if !logical_schema.matches_arrow_schema(&physical_plan.schema()) { + return plan_err!( + "{} created an ExecutionPlan with mismatched schema. \ + LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", + context, + logical_schema, + physical_plan.schema() + ); + } + Ok(()) + } + + /// Collect uncorrelated scalar subqueries. We don't descend into nested + /// subqueries here: each call to `create_initial_plan` handles subqueries + /// at its level and then recurses in order to handle nested subqueries. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Subquery contains Arc with interior mutability but is intentionally used as hash key + fn collect_scalar_subqueries(plan: &LogicalPlan) -> Vec { + let mut subqueries = IndexSet::new(); + plan.apply(|node| { + for expr in node.expressions() { + expr.apply(|e| { + if let Expr::ScalarSubquery(sq) = e + && sq.outer_ref_columns.is_empty() + { + subqueries.insert(sq.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + subqueries.into_iter().collect() + } + + /// Create a physical plan from a logical plan. + /// + /// Uncorrelated scalar subqueries in the plan's own expressions are + /// collected, planned as separate physical plans, and each assigned an + /// index in a shared [`ScalarSubqueryResults`] container that will hold its + /// result at execution time. 
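+    /// For example, given `WHERE x > (SELECT min(y) FROM u)`, the single
+    /// uncorrelated subquery is planned once and assigned index 0; the
+    /// rewritten predicate (schematically, `x > ScalarSubqueryExpr(0)`) then
+    /// reads slot 0 of the shared container when it is evaluated.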
The index map and shared results container are + /// registered in [`ExecutionProps`] so that [`create_physical_expr`] can + /// convert `Expr::ScalarSubquery` into [`ScalarSubqueryExpr`] nodes that + /// read from that container. + /// + /// The resulting physical plan is wrapped in a [`ScalarSubqueryExec`] node + /// that executes those subquery plans before any data flows through the + /// main plan. If a subquery itself contains nested uncorrelated subqueries, + /// the recursive call produces its own [`ScalarSubqueryExec`] inside the + /// subquery plan — each level manages only its own subqueries. + /// + /// Returns a [`BoxFuture`] rather than using `async fn` because of + /// this recursion. + /// + /// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + /// [`BoxFuture`]: futures::future::BoxFuture + fn create_initial_plan<'a>( + &'a self, + logical_plan: &'a LogicalPlan, + session_state: &'a SessionState, + ) -> futures::future::BoxFuture<'a, Result>> { + Box::pin(async move { + let all_subqueries = Self::collect_scalar_subqueries(logical_plan); + let (links, index_map) = self + .plan_scalar_subqueries(all_subqueries, session_state) + .await?; + + if links.is_empty() { + return self + .create_initial_plan_inner(logical_plan, session_state) + .await; + } + + // Create the shared `ScalarSubqueryResults` container and register + // it in `ExecutionProps` so that `create_physical_expr` can resolve + // `Expr::ScalarSubquery` into `ScalarSubqueryExpr` nodes. We clone + // the `SessionState` so these are available throughout physical + // planning without mutating the caller's state. + // + // Ideally, the subquery state would live in a dedicated planning + // context rather than in `ExecutionProps`. It's here because + // `create_physical_expr` only receives `&ExecutionProps`. + let results = ScalarSubqueryResults::new(links.len()); + let mut owned = session_state.clone(); + owned.execution_props_mut().subquery_indexes = index_map; + owned.execution_props_mut().subquery_results = results.clone(); + let session_state = Cow::Owned(owned); + + let plan = self + .create_initial_plan_inner(logical_plan, &session_state) + .await?; + Ok(Arc::new(ScalarSubqueryExec::new(plan, links, results))) + }) + } + + /// Inner physical planning that converts a logical plan tree into an + /// execution plan tree without collecting scalar subqueries. + async fn create_initial_plan_inner( &self, logical_plan: &LogicalPlan, session_state: &SessionState, @@ -347,11 +535,11 @@ impl DefaultPhysicalPlanner { .flatten() .collect::>(); // Ideally this never happens if we have a valid LogicalPlan tree - if outputs.len() != 1 { - return internal_err!( - "Failed to convert LogicalPlan to ExecutionPlan: More than one root detected" - ); - } + assert_eq_or_internal_err!( + outputs.len(), + 1, + "Failed to convert LogicalPlan to ExecutionPlan: More than one root detected" + ); let plan = outputs.pop().unwrap(); Ok(plan) } @@ -447,27 +635,56 @@ impl DefaultPhysicalPlanner { session_state: &SessionState, children: ChildrenContainer, ) -> Result> { + let execution_props = session_state.execution_props(); let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. 
- }) => { - let source = source_as_provider(source)?; - // Remove all qualifiers from the scan as the provider - // doesn't know (nor should care) how the relation was - // referred to in the query - let filters = unnormalize_cols(filters.iter().cloned()); - let filters_vec = filters.into_iter().collect::>(); - let opts = ScanArgs::default() - .with_projection(projection.as_deref()) - .with_filters(Some(&filters_vec)) - .with_limit(*fetch); - let res = source.scan_with_args(session_state, opts).await?; - Arc::clone(res.plan()) + LogicalPlan::TableScan(scan) => { + let TableScan { + source, + projection, + filters, + fetch, + projected_schema, + .. + } = scan; + + if let Ok(source) = source_as_provider(source) { + // Remove all qualifiers from the scan as the provider + // doesn't know (nor should care) how the relation was + // referred to in the query + let filters = unnormalize_cols(filters.iter().cloned()); + let filters_vec = filters.into_iter().collect::>(); + let opts = ScanArgs::default() + .with_projection(projection.as_deref()) + .with_filters(Some(&filters_vec)) + .with_limit(*fetch); + let res = source.scan_with_args(session_state, opts).await?; + Arc::clone(res.plan()) + } else { + let mut maybe_plan = None; + for planner in &self.extension_planners { + if maybe_plan.is_some() { + break; + } + + maybe_plan = + planner.plan_table_scan(self, scan, session_state).await?; + } + + let plan = match maybe_plan { + Some(plan) => plan, + None => { + return plan_err!( + "No installed planner was able to plan TableScan for custom TableSource: {:?}", + scan.table_name + ); + } + }; + let context = + format!("Extension planner for table scan {}", scan.table_name); + self.ensure_schema_matches(projected_schema, &plan, &context)?; + plan + } } LogicalPlan::Values(Values { values, schema }) => { let exprs = values @@ -475,7 +692,7 @@ impl DefaultPhysicalPlanner { .map(|row| { row.iter() .map(|expr| { - self.create_physical_expr(expr, schema, session_state) + create_physical_expr(expr, schema, execution_props) }) .collect::>>>() }) @@ -496,7 +713,7 @@ impl DefaultPhysicalPlanner { output_schema, }) => { let output_schema = Arc::clone(output_schema.inner()); - self.plan_describe(Arc::clone(schema), output_schema)? + self.plan_describe(&Arc::clone(schema), output_schema)? 
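+                // DESCRIBE is answered without scanning the table:
+                // `plan_describe` below synthesizes an in-memory record batch
+                // directly from the table schema.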
} // 1 Child @@ -525,16 +742,48 @@ impl DefaultPhysicalPlanner { let keep_partition_by_columns = match source_option_tuples .get("execution.keep_partition_by_columns") - .map(|v| v.trim()) { - None => session_state.config().options().execution.keep_partition_by_columns, + .map(|v| v.trim()) + { + None => { + session_state + .config() + .options() + .execution + .keep_partition_by_columns + } Some("true") => true, Some("false") => false, - Some(value) => - return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\""))), + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\"" + ))); + } }; + // Parse single_file_output option if explicitly set + let file_output_mode = match source_option_tuples + .get("single_file_output") + .map(|v| v.trim()) + { + None => FileOutputMode::Automatic, + Some("true") => FileOutputMode::SingleFile, + Some("false") => FileOutputMode::Directory, + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'single_file_output' was not recognized: \"{value}\"" + ))); + } + }; + + // Filter out sink-related options that are not format options + let format_options: HashMap = source_option_tuples + .iter() + .filter(|(k, _)| k.as_str() != "single_file_output") + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let sink_format = file_type_to_format(file_type)? - .create(session_state, source_option_tuples)?; + .create(session_state, &format_options)?; // Determine extension based on format extension and compression let file_extension = match sink_format.compression_type() { @@ -555,6 +804,7 @@ impl DefaultPhysicalPlanner { insert_op: InsertOp::Append, keep_partition_by_columns, file_extension, + file_output_mode, }; let ordering = input_exec.properties().output_ordering().cloned(); @@ -573,9 +823,7 @@ impl DefaultPhysicalPlanner { op: WriteOp::Insert(insert_op), .. }) => { - if let Some(provider) = - target.as_any().downcast_ref::() - { + if let Some(provider) = target.downcast_ref::() { let input_exec = children.one()?; provider .table_provider @@ -587,18 +835,89 @@ impl DefaultPhysicalPlanner { ); } } - LogicalPlan::Window(Window { window_expr, .. }) => { - if window_expr.is_empty() { - return internal_err!("Impossibly got empty window expression"); + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Delete, + input, + .. + }) => { + if let Some(provider) = target.downcast_ref::() { + let filters = extract_dml_filters(input, table_name)?; + provider + .table_provider + .delete_from(session_state, filters) + .await + .map_err(|e| { + e.context(format!("DELETE operation on table '{table_name}'")) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); } + } + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Update, + input, + .. 
+ }) => { + if let Some(provider) = target.downcast_ref::() { + // For UPDATE, the assignments are encoded in the projection of input + // We pass the filters and let the provider handle the projection + let filters = extract_dml_filters(input, table_name)?; + // Extract assignments from the projection in input plan + let assignments = extract_update_assignments(input)?; + provider + .table_provider + .update(session_state, assignments, filters) + .await + .map_err(|e| { + e.context(format!("UPDATE operation on table '{table_name}'")) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); + } + } + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Truncate, + .. + }) => { + if let Some(provider) = target.downcast_ref::() { + provider + .table_provider + .truncate(session_state) + .await + .map_err(|e| { + e.context(format!( + "TRUNCATE operation on table '{table_name}'" + )) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); + } + } + LogicalPlan::Window(Window { window_expr, .. }) => { + assert_or_internal_err!( + !window_expr.is_empty(), + "Impossibly got empty window expression" + ); let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { Expr::WindowFunction(window_fun) => { let WindowFunctionParams { - ref partition_by, - ref order_by, + partition_by, + order_by, .. } = &window_fun.as_ref().params; generate_sort_key(partition_by, order_by) @@ -608,8 +927,8 @@ impl DefaultPhysicalPlanner { match &**expr { Expr::WindowFunction(window_fun) => { let WindowFunctionParams { - ref partition_by, - ref order_by, + partition_by, + order_by, .. } = &window_fun.as_ref().params; generate_sort_key(partition_by, order_by) @@ -622,23 +941,17 @@ impl DefaultPhysicalPlanner { let sort_keys = get_sort_keys(&window_expr[0])?; if window_expr.len() > 1 { debug_assert!( - window_expr[1..] - .iter() - .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), - "all window expressions shall have the same sort keys, as guaranteed by logical planning" - ); + window_expr[1..] 
+ .iter() + .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), + "all window expressions shall have the same sort keys, as guaranteed by logical planning" + ); } let logical_schema = node.schema(); let window_expr = window_expr .iter() - .map(|e| { - create_window_expr( - e, - logical_schema, - session_state.execution_props(), - ) - }) + .map(|e| create_window_expr(e, logical_schema, execution_props)) .collect::>>()?; let can_repartition = session_state.config().target_partitions() > 1 @@ -683,6 +996,17 @@ impl DefaultPhysicalPlanner { ) { let mut differences = Vec::new(); + + if physical_input_schema.metadata() + != physical_input_schema_from_logical.metadata() + { + differences.push(format!( + "schema metadata differs: (physical) {:?} vs (logical) {:?}", + physical_input_schema.metadata(), + physical_input_schema_from_logical.metadata() + )); + } + if physical_input_schema.fields().len() != physical_input_schema_from_logical.fields().len() { @@ -712,18 +1036,27 @@ impl DefaultPhysicalPlanner { if physical_field.is_nullable() && !logical_field.is_nullable() { differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); } + if physical_field.metadata() != logical_field.metadata() { + differences.push(format!( + "field metadata at index {} [{}]: (physical) {:?} vs (logical) {:?}", + i, + physical_field.name(), + physical_field.metadata(), + logical_field.metadata() + )); + } } - return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences - .iter() - .map(|s| format!("\n\t- {s}")) - .join("")); + return internal_err!( + "Physical input schema should be the same as the one converted from logical input schema. Differences: {}", + differences.iter().map(|s| format!("\n\t- {s}")).join("") + ); } let groups = self.create_grouping_physical_expr( group_expr, logical_input_schema, &physical_input_schema, - session_state, + execution_props, )?; let agg_filter = aggr_expr @@ -733,7 +1066,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - session_state.execution_props(), + execution_props, ) }) .collect::>>()?; @@ -776,7 +1109,7 @@ impl DefaultPhysicalPlanner { _ => { return internal_err!( "Unexpected result from try_plan_async_exprs" - ) + ); } } } @@ -827,8 +1160,8 @@ impl DefaultPhysicalPlanner { )?) } LogicalPlan::Projection(Projection { input, expr, .. }) => self - .create_project_physical_exec( - session_state, + .create_project_physical_exec_with_props( + execution_props, children.one()?, input, expr, @@ -838,9 +1171,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.schema(); - let runtime_expr = - self.create_physical_expr(predicate, input_dfschema, session_state)?; + create_physical_expr(predicate, input_dfschema, execution_props)?; let input_schema = input.schema(); let filter = match self.try_plan_async_exprs( @@ -849,7 +1181,12 @@ impl DefaultPhysicalPlanner { input_schema.as_arrow(), )? { PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { - FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? + FilterExecBuilder::new( + Arc::clone(&runtime_expr[0]), + physical_input, + ) + .with_batch_size(session_state.config().batch_size()) + .build()? 
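+                        // The async path below also goes through
+                        // FilterExecBuilder, threading the same session batch
+                        // size into the operator.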
} PlanAsyncExpr::Async( async_map, @@ -859,20 +1196,22 @@ impl DefaultPhysicalPlanner { async_map.async_exprs, physical_input, )?; - FilterExec::try_new( + FilterExecBuilder::new( Arc::clone(&runtime_expr[0]), Arc::new(async_exec), - )? + ) // project the output columns excluding the async functions // The async functions are always appended to the end of the schema. - .with_projection(Some( - (0..input.schema().fields().len()).collect(), + .apply_projection(Some( + (0..input.schema().fields().len()).collect::>(), ))? + .with_batch_size(session_state.config().batch_size()) + .build()? } _ => { return internal_err!( "Unexpected result from try_plan_async_exprs" - ) + ); } }; @@ -881,7 +1220,9 @@ impl DefaultPhysicalPlanner { .options() .optimizer .default_filter_selectivity; - Arc::new(filter.with_default_selectivity(selectivity)?) + let filter_exec: Arc = + Arc::new(filter.with_default_selectivity(selectivity)?); + filter_exec } LogicalPlan::Repartition(Repartition { input, @@ -897,11 +1238,7 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr( - e, - input_dfschema, - session_state, - ) + create_physical_expr(e, input_dfschema, execution_props) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) @@ -922,11 +1259,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_exprs = create_physical_sort_exprs( - expr, - input_dfschema, - session_state.execution_props(), - )?; + let sort_exprs = + create_physical_sort_exprs(expr, input_dfschema, execution_props)?; let Some(ordering) = LexOrdering::new(sort_exprs) else { return internal_err!( "SortExec requires at least one sort expression" @@ -935,7 +1269,14 @@ impl DefaultPhysicalPlanner { let new_sort = SortExec::new(ordering, physical_input).with_fetch(*fetch); Arc::new(new_sort) } - LogicalPlan::Subquery(_) => todo!(), + // The optimizer's decorrelation passes remove Subquery nodes + // for supported patterns. This error is hit for correlated + // patterns that the optimizer cannot (yet) decorrelate. + LogicalPlan::Subquery(_) => { + return not_impl_err!( + "Physical plan does not support undecorrelated Subquery" + ); + } LogicalPlan::SubqueryAlias(_) => children.one()?, LogicalPlan::Limit(limit) => { let input = children.one()?; @@ -1000,6 +1341,7 @@ impl DefaultPhysicalPlanner { filter, join_type, null_equality, + null_aware, schema: join_schema, .. }) => { @@ -1045,8 +1387,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. }), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_left, input, expr, @@ -1058,8 +1400,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. 
}), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_right, input, expr, @@ -1124,7 +1466,6 @@ impl DefaultPhysicalPlanner { // All equi-join keys are columns now, create physical join plan let left_df_schema = left.schema(); let right_df_schema = right.schema(); - let execution_props = session_state.execution_props(); let join_on = keys .iter() .map(|(l, r)| { @@ -1207,7 +1548,7 @@ impl DefaultPhysicalPlanner { let filter_df_fields = filter_df_fields .into_iter() .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) + (qualifier.cloned(), Arc::clone(field)) }) .collect(); @@ -1230,7 +1571,7 @@ impl DefaultPhysicalPlanner { let filter_expr = create_physical_expr( expr, &filter_df_schema, - session_state.execution_props(), + execution_props, )?; let column_indices = join_utils::JoinFilter::build_column_indices( left_field_indices, @@ -1251,7 +1592,7 @@ impl DefaultPhysicalPlanner { // TODO: Allow PWMJ to deal with residual equijoin conditions let join: Arc = if join_on.is_empty() { - if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + if join_filter.is_none() && *join_type == JoinType::Inner { // cross join if there is no join conditions and no join filter set Arc::new(CrossJoinExec::new(physical_left, physical_right)) } else if num_range_filters == 1 @@ -1326,9 +1667,7 @@ impl DefaultPhysicalPlanner { let left_side = side_of(lhs_logical)?; let right_side = side_of(rhs_logical)?; - if matches!(left_side, Side::Both) - || matches!(right_side, Side::Both) - { + if left_side == Side::Both || right_side == Side::Both { return Ok(Arc::new(NestedLoopJoinExec::try_new( physical_left, physical_right, @@ -1352,12 +1691,12 @@ impl DefaultPhysicalPlanner { let on_left = create_physical_expr( lhs_logical, left_df_schema, - session_state.execution_props(), + execution_props, )?; let on_right = create_physical_expr( rhs_logical, right_df_schema, - session_state.execution_props(), + execution_props, )?; Arc::new(PiecewiseMergeJoinExec::try_new( @@ -1396,6 +1735,8 @@ impl DefaultPhysicalPlanner { } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && prefer_hash_join + && !*null_aware + // Null-aware joins must use CollectLeft { Arc::new(HashJoinExec::try_new( physical_left, @@ -1406,6 +1747,7 @@ impl DefaultPhysicalPlanner { None, PartitionMode::Auto, *null_equality, + *null_aware, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1417,13 +1759,19 @@ impl DefaultPhysicalPlanner { None, PartitionMode::CollectLeft, *null_equality, + *null_aware, )?) }; // If plan was mutated previously then need to create the ExecutionPlan // for the new Projection that was applied on top. if let Some((input, expr)) = new_project { - self.create_project_physical_exec(session_state, join, input, expr)? + self.create_project_physical_exec_with_props( + execution_props, + join, + input, + expr, + )? } else { join } @@ -1463,22 +1811,16 @@ impl DefaultPhysicalPlanner { } let plan = match maybe_plan { - Some(v) => Ok(v), - _ => plan_err!("No installed planner was able to convert the custom node to an execution plan: {:?}", node) - }?; - - // Ensure the ExecutionPlan's schema matches the - // declared logical schema to catch and warn about - // logic errors when creating user defined plans. - if !node.schema().matches_arrow_schema(&plan.schema()) { - return plan_err!( - "Extension planner for {:?} created an ExecutionPlan with mismatched schema. 
\ - LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", - node, node.schema(), plan.schema() - ); - } else { - plan - } + Some(v) => Ok(v), + _ => plan_err!( + "No installed planner was able to convert the custom node to an execution plan: {:?}", + node + ), + }?; + + let context = format!("Extension planner for {node:?}"); + self.ensure_schema_matches(node.schema(), &plan, &context)?; + plan } // Other @@ -1502,17 +1844,17 @@ impl DefaultPhysicalPlanner { LogicalPlan::Explain(_) => { return internal_err!( "Unsupported logical plan: Explain must be root of the plan" - ) + ); } LogicalPlan::Distinct(_) => { return internal_err!( "Unsupported logical plan: Distinct should be replaced to Aggregate" - ) + ); } LogicalPlan::Analyze(_) => { return internal_err!( "Unsupported logical plan: Analyze must be root of the plan" - ) + ); } }; Ok(exec_node) @@ -1523,7 +1865,7 @@ impl DefaultPhysicalPlanner { group_expr: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { if group_expr.len() == 1 { match &group_expr[0] { @@ -1532,38 +1874,39 @@ impl DefaultPhysicalPlanner { grouping_sets, input_dfschema, input_schema, - session_state, + execution_props, ) } Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ), Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { create_rollup_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ) } expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( - self.create_physical_expr(expr, input_dfschema, session_state), + create_physical_expr(expr, input_dfschema, execution_props), physical_name(expr), ))?])), } } else if group_expr.is_empty() { // No GROUP BY clause - create empty PhysicalGroupBy - Ok(PhysicalGroupBy::new(vec![], vec![], vec![])) + // no expressions, no null expressions and no grouping expressions + Ok(PhysicalGroupBy::new(vec![], vec![], vec![], false)) } else { Ok(PhysicalGroupBy::new_single( group_expr .iter() .map(|e| { tuple_err(( - self.create_physical_expr(e, input_dfschema, session_state), + create_physical_expr(e, input_dfschema, execution_props), physical_name(e), )) }) @@ -1587,7 +1930,7 @@ fn merge_grouping_set_physical_expr( grouping_sets: &[Vec], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_groups = grouping_sets.len(); let mut all_exprs: Vec = vec![]; @@ -1601,14 +1944,14 @@ fn merge_grouping_set_physical_expr( grouping_set_expr.push(get_physical_expr_pair( expr, input_dfschema, - session_state, + execution_props, )?); null_exprs.push(get_null_physical_expr_pair( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); } } @@ -1628,6 +1971,7 @@ fn merge_grouping_set_physical_expr( grouping_set_expr, null_exprs, merged_sets, + true, )) } @@ -1637,7 +1981,7 @@ fn create_cube_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); let num_groups = num_of_exprs * num_of_exprs; @@ -1652,10 +1996,14 @@ fn create_cube_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) 
} let mut groups: Vec> = Vec::with_capacity(num_groups); @@ -1670,7 +2018,7 @@ fn create_cube_physical_expr( } } - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups, true)) } /// Expand and align a ROLLUP expression. This is a special case of GROUPING SETS @@ -1679,7 +2027,7 @@ fn create_rollup_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); @@ -1695,10 +2043,14 @@ fn create_rollup_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) } for total in 0..=num_of_exprs { @@ -1715,7 +2067,7 @@ fn create_rollup_physical_expr( groups.push(group) } - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups, true)) } /// For a given logical expr, get a properly typed NULL ScalarValue physical expression @@ -1723,10 +2075,9 @@ fn get_null_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(&expr.clone())?; let data_type = physical_expr.data_type(input_schema)?; @@ -1752,11 +2103,11 @@ fn qualify_join_schema_sides( let join_fields = join_schema.fields(); // Validate lengths - if join_fields.len() != left_fields.len() + right_fields.len() { - return internal_err!( - "Join schema field count must match left and right field count." - ); - } + assert_eq_or_internal_err!( + join_fields.len(), + left_fields.len() + right_fields.len(), + "Join schema field count must match left and right field count." + ); // Validate field names match for (i, (field, expected)) in join_fields @@ -1764,14 +2115,12 @@ fn qualify_join_schema_sides( .zip(left_fields.iter().chain(right_fields.iter())) .enumerate() { - if field.name() != expected.name() { - return internal_err!( - "Field name mismatch at index {}: expected '{}', found '{}'", - i, - expected.name(), - field.name() - ); - } + assert_eq_or_internal_err!( + field.name(), + expected.name(), + "Field name mismatch at index {}", + i + ); } // qualify sides @@ -1797,14 +2146,240 @@ fn qualify_join_schema_sides( fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } +/// Extract filter predicates from a DML input plan (DELETE/UPDATE). +/// +/// Walks the logical plan tree and collects Filter predicates and any filters +/// pushed down into TableScan nodes, splitting AND conjunctions into individual expressions. +/// +/// For UPDATE...FROM queries involving multiple tables, this function only extracts predicates +/// that reference the target table. 
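+/// For example, for `DELETE FROM t WHERE t.a > 5 AND b = 'x'` the
+/// conjunction is split and qualifiers are stripped, yielding the two
+/// predicates `a > 5` and `b = 'x'`.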
Filters from source table scans are excluded to prevent +/// incorrect filter semantics. +/// +/// Column qualifiers are stripped so expressions can be evaluated against the TableProvider's +/// schema. Deduplication is performed because filters may appear in both Filter nodes and +/// TableScan.filters when the optimizer performs partial (Inexact) filter pushdown. +/// +/// # Parameters +/// - `input`: The logical plan tree to extract filters from (typically a DELETE or UPDATE plan) +/// - `target`: The target table reference to scope filter extraction (prevents multi-table filter leakage) +/// +/// # Returns +/// A vector of unqualified filter expressions that can be passed to the TableProvider for execution. +/// Returns an empty vector if no applicable filters are found. +/// +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key +fn extract_dml_filters( + input: &Arc, + target: &TableReference, +) -> Result> { + let mut filters = Vec::new(); + let mut allowed_refs = vec![target.clone()]; + + // First pass: collect any alias references to the target table + input.apply(|node| { + if let LogicalPlan::SubqueryAlias(alias) = node + // Check if this alias points to the target table + && let LogicalPlan::TableScan(scan) = alias.input.as_ref() + && scan.table_name.resolved_eq(target) + { + allowed_refs.push(TableReference::bare(alias.alias.to_string())); + } + Ok(TreeNodeRecursion::Continue) + })?; + + input.apply(|node| { + match node { + LogicalPlan::Filter(filter) => { + // Split AND predicates into individual expressions + for predicate in split_conjunction(&filter.predicate) { + if predicate_is_on_target_multi(predicate, &allowed_refs)? { + filters.push(predicate.clone()); + } + } + } + LogicalPlan::TableScan(TableScan { + table_name, + filters: scan_filters, + .. + }) => { + // Only extract filters from the target table scan. + // This prevents incorrect filter extraction in UPDATE...FROM scenarios + // where multiple table scans may have filters. + if table_name.resolved_eq(target) { + for filter in scan_filters { + filters.extend(split_conjunction(filter).into_iter().cloned()); + } + } + } + // Plans without filter information + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Values(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) => { + // No filters to extract from leaf/meta plans + } + // Plans with inputs (may contain filters in children) + LogicalPlan::Projection(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Union(_) + | LogicalPlan::Join(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::Subquery(_) => { + // Filter information may appear in child nodes; continue traversal + // to extract filters from Filter/TableScan nodes deeper in the plan + } + } + Ok(TreeNodeRecursion::Continue) + })?; + + // Strip qualifiers and deduplicate. This ensures: + // 1. Only target-table predicates are retained from Filter nodes + // 2. Qualifiers stripped for TableProvider compatibility + // 3. 
Duplicates removed (from Filter nodes + TableScan.filters) + // + // Deduplication is necessary because filters may appear in both Filter nodes + // and TableScan.filters when the optimizer performs partial (Inexact) pushdown. + let mut seen_filters = HashSet::new(); + filters + .into_iter() + .try_fold(Vec::new(), |mut deduped, filter| { + let unqualified = strip_column_qualifiers(filter).map_err(|e| { + e.context(format!( + "Failed to strip column qualifiers for DML filter on table '{target}'" + )) + })?; + if seen_filters.insert(unqualified.clone()) { + deduped.push(unqualified); + } + Ok(deduped) + }) +} + +/// Determine whether a predicate references only columns from the target table +/// or its aliases. +/// +/// Columns may be qualified with the target table name or any of its aliases. +/// Unqualified columns are also accepted as they implicitly belong to the target table. +fn predicate_is_on_target_multi( + expr: &Expr, + allowed_refs: &[TableReference], +) -> Result { + let mut columns = HashSet::new(); + expr_to_columns(expr, &mut columns)?; + + // Short-circuit on first mismatch: returns false if any column references a table not in allowed_refs. + // Columns are accepted if: + // 1. They are unqualified (no relation specified), OR + // 2. Their relation matches one of the allowed table references using resolved equality + Ok(!columns.iter().any(|column| { + column.relation.as_ref().is_some_and(|relation| { + !allowed_refs + .iter() + .any(|allowed| relation.resolved_eq(allowed)) + }) + })) +} + +/// Strip table qualifiers from column references in an expression. +/// This is needed because DML filter expressions contain qualified column names +/// (e.g., "table.column") but the TableProvider's schema only has simple names. +fn strip_column_qualifiers(expr: Expr) -> Result { + expr.transform(|e| { + if let Expr::Column(col) = &e + && col.relation.is_some() + { + // Strip the qualifier + return Ok(Transformed::yes(Expr::Column(Column::new_unqualified( + col.name.clone(), + )))); + } + Ok(Transformed::no(e)) + }) + .map(|t| t.data) +} + +/// Extract column assignments from an UPDATE input plan. +/// For UPDATE statements, the SQL planner encodes assignments as a projection +/// over the source table. This function extracts column name and expression pairs +/// from the projection. Column qualifiers are stripped from the expressions. 
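+///
+/// For example, for a statement like the following (plan shape abbreviated;
+/// a sketch, not a doctest):
+///
+/// ```rust,ignore
+/// // UPDATE t SET b = b + 1 WHERE a = 5 plans roughly as
+/// //   Projection: t.a AS a, t.b + 1 AS b
+/// //     Filter: t.a = 5
+/// //       TableScan: t
+/// let assignments = extract_update_assignments(&input)?;
+/// // only ("b", b + 1) is returned: the identity assignment for `a` is
+/// // skipped and the `t.` qualifier is stripped
+/// assert_eq!(assignments.len(), 1);
+/// ```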
+///
+fn extract_update_assignments(input: &Arc<LogicalPlan>) -> Result<Vec<(String, Expr)>> {
+    // The UPDATE input plan structure is:
+    //   Projection(updated columns as expressions with aliases)
+    //     Filter(optional WHERE clause)
+    //       TableScan
+    //
+    // Each projected expression has an alias matching the column name
+    let mut assignments = Vec::new();
+
+    // Find the top-level projection
+    if let LogicalPlan::Projection(projection) = input.as_ref() {
+        for expr in &projection.expr {
+            if let Expr::Alias(alias) = expr {
+                // The alias name is the column name being updated
+                // The inner expression is the new value
+                let column_name = alias.name.clone();
+                // Only include if it's not just a column reference to itself
+                // (those are columns that aren't being updated)
+                if !is_identity_assignment(&alias.expr, &column_name) {
+                    // Strip qualifiers from the assignment expression
+                    let stripped_expr = strip_column_qualifiers((*alias.expr).clone())?;
+                    assignments.push((column_name, stripped_expr));
+                }
+            }
+        }
+    } else {
+        // Try to find projection deeper in the plan
+        input.apply(|node| {
+            if let LogicalPlan::Projection(projection) = node {
+                for expr in &projection.expr {
+                    if let Expr::Alias(alias) = expr {
+                        let column_name = alias.name.clone();
+                        if !is_identity_assignment(&alias.expr, &column_name) {
+                            let stripped_expr =
+                                strip_column_qualifiers((*alias.expr).clone())?;
+                            assignments.push((column_name, stripped_expr));
+                        }
+                    }
+                }
+                return Ok(TreeNodeRecursion::Stop);
+            }
+            Ok(TreeNodeRecursion::Continue)
+        })?;
+    }
+
+    Ok(assignments)
+}
+
+/// Check if an assignment is an identity assignment (column = column)
+/// These are columns that are not being modified in the UPDATE
+fn is_identity_assignment(expr: &Expr, column_name: &str) -> bool {
+    match expr {
+        Expr::Column(col) => col.name == column_name,
+        _ => false,
+    }
+}
+
 /// Check if window bounds are valid after schema information is available, and
 /// window_frame bounds are casted to the corresponding column type.
/// queries like: @@ -1858,9 +2433,10 @@ pub fn create_window_expr_with_name( if !is_window_frame_bound_valid(window_frame) { return plan_err!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - window_frame.start_bound, window_frame.end_bound - ); + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + window_frame.start_bound, + window_frame.end_bound + ); } let window_frame = Arc::new(window_frame.clone()); @@ -2228,14 +2804,21 @@ impl DefaultPhysicalPlanner { let schema = Arc::clone(a.schema.inner()); let show_statistics = session_state.config_options().explain.show_statistics; let analyze_level = session_state.config_options().explain.analyze_level; - let metric_types = match analyze_level { - ExplainAnalyzeLevel::Summary => vec![MetricType::SUMMARY], - ExplainAnalyzeLevel::Dev => vec![MetricType::SUMMARY, MetricType::DEV], + let metric_types = analyze_level.included_types(); + let analyze_categories = session_state + .config_options() + .explain + .analyze_categories + .clone(); + let metric_categories = match analyze_categories { + ExplainAnalyzeCategories::All => None, + ExplainAnalyzeCategories::Only(cats) => Some(cats), }; Ok(Arc::new(AnalyzeExec::new( a.verbose, show_statistics, metric_types, + metric_categories, input, schema, ))) @@ -2243,6 +2826,7 @@ impl DefaultPhysicalPlanner { /// Optimize a physical plan by applying each physical optimizer, /// calling observer(plan, optimizer after each one) + #[expect(clippy::needless_pass_by_value)] pub fn optimize_physical_plan( &self, plan: Arc, @@ -2270,14 +2854,14 @@ impl DefaultPhysicalPlanner { for optimizer in optimizers { let before_schema = new_plan.schema(); new_plan = optimizer - .optimize(new_plan, session_state.config_options()) + .optimize_with_context(new_plan, session_state) .map_err(|e| { DataFusionError::Context(optimizer.name().to_string(), Box::new(e)) })?; // This only checks the schema in release build, and performs additional checks in debug mode. OptimizationInvariantChecker::new(optimizer) - .check(&new_plan, before_schema)?; + .check(&new_plan, &before_schema)?; debug!( "Optimized physical plan by {}:\n{}\n", @@ -2310,7 +2894,7 @@ impl DefaultPhysicalPlanner { // return an record_batch which describes a table's schema. fn plan_describe( &self, - table_schema: Arc, + table_schema: &Arc, output_schema: Arc, ) -> Result> { let mut column_names = StringBuilder::new(); @@ -2344,9 +2928,37 @@ impl DefaultPhysicalPlanner { Ok(mem_exec) } - fn create_project_physical_exec( + /// Build physical plans for scalar subqueries and assign each an ordinal + /// `SubqueryIndex`. Returns the links (plan + index) and a map from logical + /// `Subquery` to its index. + async fn plan_scalar_subqueries( &self, + subqueries: Vec, session_state: &SessionState, + ) -> Result<(Vec, DFHashMap)> { + let mut links = Vec::with_capacity(subqueries.len()); + let mut index_map = DFHashMap::with_capacity(subqueries.len()); + for sq in subqueries { + // Callers deduplicate, but guard against accidental double-planning. 
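+            // (If the same `Subquery` does appear twice, the first planned
+            // copy wins and later occurrences keep its `SubqueryIndex`.)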
+ if index_map.contains_key(&sq) { + continue; + } + let physical_plan = self + .create_initial_plan(&sq.subquery, session_state) + .await?; + let index = SubqueryIndex::new(links.len()); + links.push(ScalarSubqueryLink { + plan: physical_plan, + index, + }); + index_map.insert(sq, index); + } + Ok((links, index_map)) + } + + fn create_project_physical_exec_with_props( + &self, + execution_props: &ExecutionProps, input_exec: Arc, input: &Arc, expr: &[Expr], @@ -2385,7 +2997,7 @@ impl DefaultPhysicalPlanner { }; let physical_expr = - self.create_physical_expr(e, input_logical_schema, session_state); + create_physical_expr(e, input_logical_schema, execution_props); tuple_err((physical_expr, physical_name)) }) @@ -2513,11 +3125,14 @@ impl<'a> OptimizationInvariantChecker<'a> { pub fn check( &mut self, plan: &Arc, - previous_schema: Arc, + previous_schema: &Arc, ) -> Result<()> { // if the rule is not permitted to change the schema, confirm that it did not change. - if self.rule.schema_check() && plan.schema() != previous_schema { - internal_err!("PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", + if self.rule.schema_check() + && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) + { + internal_err!( + "PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {}, got new schema: {}", self.rule.name(), previous_schema, plan.schema() @@ -2532,6 +3147,38 @@ impl<'a> OptimizationInvariantChecker<'a> { } } +/// Checks if the change from `old` schema to `new` is allowed or not. +/// +/// The current implementation only allows nullability of individual fields to change +/// from 'nullable' to 'not nullable'. This can happen due to physical expressions knowing +/// more about their null-ness than their logical counterparts. +/// This change is allowed because for any field the non-nullable domain `F` is a strict subset +/// of the nullable domain `F ∪ { NULL }`. A physical schema that guarantees a stricter subset +/// of values will not violate any assumptions made based on the less strict schema. 
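+///
+/// For example, a field declared `("a", Int32, nullable)` may become
+/// `("a", Int32, non-nullable)` during optimization, while the reverse
+/// relaxation, or any change to field names, data types, or metadata, is
+/// rejected.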
+fn is_allowed_schema_change(old: &Schema, new: &Schema) -> bool { + if new.metadata != old.metadata { + return false; + } + + if new.fields.len() != old.fields.len() { + return false; + } + + let new_fields = new.fields.iter().map(|f| f.as_ref()); + let old_fields = old.fields.iter().map(|f| f.as_ref()); + old_fields + .zip(new_fields) + .all(|(old, new)| is_allowed_field_change(old, new)) +} + +fn is_allowed_field_change(old_field: &Field, new_field: &Field) -> bool { + new_field.name() == old_field.name() + && new_field.data_type() == old_field.data_type() + && new_field.metadata() == old_field.metadata() + && (new_field.is_nullable() == old_field.is_nullable() + || !new_field.is_nullable()) +} + impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { type Node = Arc; @@ -2574,17 +3221,16 @@ impl<'n> TreeNodeVisitor<'n> for InvariantChecker { #[cfg(test)] mod tests { - use std::any::Any; use std::cmp::Ordering; use std::fmt::{self, Debug}; use std::ops::{BitAnd, Not}; use super::*; - use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::MemTable; + use crate::datasource::file_format::options::CsvReadOptions; use crate::physical_plan::{ - expressions, DisplayAs, DisplayFormatType, PlanProperties, - SendableRecordBatchStream, + DisplayAs, DisplayFormatType, PlanProperties, SendableRecordBatchStream, + expressions, }; use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; @@ -2595,12 +3241,14 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - assert_contains, DFSchemaRef, TableReference, ToDFSchema as _, + DFSchemaRef, TableReference, ToDFSchema as _, assert_batches_eq, assert_contains, }; - use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; + use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::builder::subquery_alias; - use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_expr::{ + LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit, + }; use datafusion_functions_aggregate::count::count_all; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; @@ -2627,6 +3275,16 @@ mod tests { .await } + async fn plan_sql(query: &str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.create_physical_plan().await + } + + async fn collect_sql(query: &str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.collect().await + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -2650,6 +3308,132 @@ mod tests { Ok(()) } + #[tokio::test] + async fn scalar_subquery_in_sort_expr_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x \ + FROM (VALUES (2), (1)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_sort_expr_executes() -> Result<()> { + let batches = collect_sql( + "SELECT x \ + FROM (VALUES (2), (1), (3)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) DESC", + ) + .await?; + + assert_batches_eq!( + &[ + "+---+", "| x |", "+---+", "| 3 |", "| 2 |", "| 1 |", "+---+", + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_plans() -> 
Result<()> { + let plan = plan_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_executes() -> Result<()> { + let batches = collect_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) AS s \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_batches_eq!( + &["+----+", "| s |", "+----+", "| 43 |", "+----+",], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_join_on_plans() -> Result<()> { + let plan = plan_sql( + "SELECT l.x, r.y \ + FROM (VALUES (1), (2)) AS l(x) \ + JOIN (VALUES (11), (12)) AS r(y) \ + ON l.x + (SELECT 10) = r.y", + ) + .await?; + + let formatted = format!("{plan:?}"); + assert_contains!(&formatted, "ScalarSubqueryExec"); + assert!( + formatted.contains("HashJoinExec") + || formatted.contains("SortMergeJoinExec") + || formatted.contains("NestedLoopJoinExec") + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_mixed_correlated_and_uncorrelated_executes() -> Result<()> { + let query = "SELECT t.x, \ + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) + \ + (SELECT count(*) FROM (VALUES (1), (1), (2)) AS v(z) WHERE v.z = t.x) AS total \ + FROM (VALUES (1), (2), (3)) AS t(x) \ + ORDER BY x"; + let plan = plan_sql(query).await?; + + let formatted = format!("{plan:?}"); + assert_eq!(formatted.matches("ScalarSubqueryExec").count(), 1); + assert!( + formatted.contains("HashJoinExec") + || formatted.contains("SortMergeJoinExec") + || formatted.contains("NestedLoopJoinExec") + ); + + let batches = collect_sql(query).await?; + assert_batches_eq!( + &[ + "+---+-------+", + "| x | total |", + "+---+-------+", + "| 1 | 22 |", + "| 2 | 21 |", + "| 3 | 20 |", + "+---+-------+", + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_projection_and_filter_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) \ + FROM (VALUES (2), (1)) AS t(x) \ + WHERE x > (SELECT min(y) FROM (VALUES (0), (1)) AS v(y))", + ) + .await?; + + let formatted = format!("{plan:?}"); + // All uncorrelated scalar subqueries are hoisted to a single root node. 
+ assert_eq!(formatted.matches("ScalarSubqueryExec").count(), 1); + Ok(()) + } + #[tokio::test] async fn test_create_cube_expr() -> Result<()> { let logical_plan = test_csv_scan().await?.build()?; @@ -2667,7 +3451,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(cube, @r#" @@ -2773,6 +3557,7 @@ mod tests { true, ], ], + has_grouping_set: true, }, ) "#); @@ -2797,7 +3582,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(rollup, @r#" @@ -2883,6 +3668,7 @@ mod tests { false, ], ], + has_grouping_set: true, }, ) "#); @@ -3000,8 +3786,7 @@ mod tests { .create_physical_plan(&logical_plan, &session_state) .await; - let expected_error = - "No installed planner was able to convert the custom node to an execution plan: NoOp"; + let expected_error = "No installed planner was able to convert the custom node to an execution plan: NoOp"; match plan { Ok(_) => panic!("Expected planning failure"), Err(e) => assert!( @@ -3033,21 +3818,17 @@ mod tests { } #[tokio::test] - async fn in_list_types() -> Result<()> { - // expression: "a in ('a', 1)" + async fn in_list_types_mixed_string_int_error() -> Result<()> { + // expression: "c1 in ('a', 1)" where c1 is Utf8 let list = vec![lit("a"), lit(1i64)]; let logical_plan = test_csv_scan() .await? - // filter clause needs the type coercion rule applied .filter(col("c12").lt(lit(0.05)))? .project(vec![col("c1").in_list(list, false)])? .build()?; - let execution_plan = plan(&logical_plan).await?; - // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - - let expected = r#"expr: BinaryExpr { left: BinaryExpr { left: Column { name: "c1", index: 0 }, op: Eq, right: Literal { value: Utf8("a"), field: Field { name: "lit", data_type: Utf8 } }, fail_on_overflow: false }"#; + let e = plan(&logical_plan).await.unwrap_err().to_string(); - assert_contains!(format!("{execution_plan:?}"), expected); + assert_contains!(&e, "Cannot cast string 'a' to value of Int64 type"); Ok(()) } @@ -3067,7 +3848,7 @@ mod tests { assert_contains!( &e, - r#"Error during planning: Can not find compatible types to compare Boolean with [Struct("foo": Boolean), Utf8]"# + r#"Error during planning: Can not find compatible types to compare Boolean with [Struct("foo": non-null Boolean), Utf8]"# ); Ok(()) @@ -3092,7 +3873,6 @@ mod tests { let execution_plan = plan(&logical_plan).await?; let final_hash_agg = execution_plan - .as_any() .downcast_ref::() .expect("hash aggregate"); assert_eq!( @@ -3120,7 +3900,6 @@ mod tests { let execution_plan = plan(&logical_plan).await?; let final_hash_agg = execution_plan - .as_any() .downcast_ref::() .expect("hash aggregate"); assert_eq!( @@ -3255,21 +4034,30 @@ mod tests { .unwrap(); let plan = plan(&logical_plan).await.unwrap(); - if let Some(plan) = plan.as_any().downcast_ref::() { + if let Some(plan) = plan.downcast_ref::() { let stringified_plans = plan.stringified_plans(); assert!(stringified_plans.len() >= 4); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalLogicalPlan))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::InitialPhysicalPlan))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::OptimizedPhysicalPlan { .. 
}))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalPhysicalPlan))); + assert!( + stringified_plans + .iter() + .any(|p| p.plan_type == PlanType::FinalLogicalPlan) + ); + assert!( + stringified_plans + .iter() + .any(|p| p.plan_type == PlanType::InitialPhysicalPlan) + ); + assert!( + stringified_plans.iter().any(|p| matches!( + p.plan_type, + PlanType::OptimizedPhysicalPlan { .. } + )) + ); + assert!( + stringified_plans + .iter() + .any(|p| p.plan_type == PlanType::FinalPhysicalPlan) + ); } else { panic!( "Plan was not an explain plan: {}", @@ -3314,7 +4102,7 @@ mod tests { .handle_explain(&explain, &ctx.state()) .await .unwrap(); - if let Some(plan) = plan.as_any().downcast_ref::() { + if let Some(plan) = plan.downcast_ref::() { let stringified_plans = plan.stringified_plans(); assert_eq!(stringified_plans.len(), 1); assert_eq!(stringified_plans[0].plan.as_str(), "Test Err"); @@ -3412,13 +4200,15 @@ mod tests { #[derive(Debug)] struct NoOpExecutionPlan { - cache: PlanProperties, + cache: Arc, } impl NoOpExecutionPlan { fn new(schema: SchemaRef) -> Self { let cache = Self::compute_properties(schema); - Self { cache } + Self { + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -3452,11 +4242,7 @@ mod tests { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -3478,6 +4264,20 @@ mod tests { ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } // Produces an execution plan where the schema is mismatched from @@ -3604,13 +4404,10 @@ digraph { fn schema(&self) -> SchemaRef { Arc::new(Schema::empty()) } - fn as_any(&self) -> &dyn Any { - unimplemented!() - } fn children(&self) -> Vec<&Arc> { self.0.iter().collect::>() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -3620,6 +4417,12 @@ digraph { ) -> Result { unimplemented!() } + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } impl DisplayAs for OkExtensionNode { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { @@ -3636,8 +4439,12 @@ digraph { } fn check_invariants(&self, check: InvariantLevel) -> Result<()> { match check { - InvariantLevel::Always => plan_err!("extension node failed it's user-defined always-invariant check"), - InvariantLevel::Executable => panic!("the OptimizationInvariantChecker should not be checking for executableness"), + InvariantLevel::Always => plan_err!( + "extension node failed it's user-defined always-invariant check" + ), + InvariantLevel::Executable => panic!( + "the OptimizationInvariantChecker should not be checking for executableness" + ), } } fn schema(&self) -> SchemaRef { @@ -3649,13 +4456,10 @@ digraph { ) -> Result> { unimplemented!() } - fn as_any(&self) -> &dyn Any { - unimplemented!() - } fn children(&self) -> 
Vec<&Arc<dyn ExecutionPlan>> { unimplemented!() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc<PlanProperties> { unimplemented!() } fn execute( @@ -3665,6 +4469,12 @@ digraph { ) -> Result<SendableRecordBatchStream> { unimplemented!() } + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>, + ) -> Result<TreeNodeRecursion> { + Ok(TreeNodeRecursion::Continue) + } } impl DisplayAs for InvariantFailsExtensionNode { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { @@ -3706,24 +4516,26 @@ digraph { // Test: check should pass with same schema let equal_schema = ok_plan.schema(); - OptimizationInvariantChecker::new(&rule).check(&ok_plan, equal_schema)?; + OptimizationInvariantChecker::new(&rule).check(&ok_plan, &equal_schema)?; // Test: should fail with schema changed let different_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&ok_plan, different_schema) + .check(&ok_plan, &different_schema) .unwrap_err(); assert!(expected_err.to_string().contains("PhysicalOptimizer rule 'OptimizerRuleWithSchemaCheck' failed. Schema mismatch. Expected original schema")); // Test: should fail when extension node fails its own invariant check let failing_node: Arc<dyn ExecutionPlan> = Arc::new(InvariantFailsExtensionNode); let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&failing_node, ok_plan.schema()) + .check(&failing_node, &ok_plan.schema()) .unwrap_err(); - assert!(expected_err - .to_string() - .contains("extension node failed it's user-defined always-invariant check")); + assert!( + expected_err.to_string().contains( + "extension node failed its user-defined always-invariant check" + ) + ); // Test: should fail when descendant extension node fails let failing_node: Arc<dyn ExecutionPlan> = Arc::new(InvariantFailsExtensionNode); @@ -3732,11 +4544,13 @@ digraph { Arc::clone(&child), ])?; let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&invalid_plan, ok_plan.schema()) + .check(&invalid_plan, &ok_plan.schema()) .unwrap_err(); - assert!(expected_err - .to_string() - .contains("extension node failed it's user-defined always-invariant check")); + assert!( + expected_err.to_string().contains( + "extension node failed its user-defined always-invariant check" + ) + ); Ok(()) } @@ -3766,13 +4580,10 @@ digraph { ) -> Result<Arc<dyn ExecutionPlan>> { unimplemented!() } - fn as_any(&self) -> &dyn Any { - unimplemented!() - } fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { vec![] } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc<PlanProperties> { unimplemented!() } fn execute( @@ -3782,6 +4593,12 @@ digraph { ) -> Result<SendableRecordBatchStream> { unimplemented!() } + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>, + ) -> Result<TreeNodeRecursion> { + Ok(TreeNodeRecursion::Continue) + } } impl DisplayAs for ExecutableInvariantFails { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { @@ -3857,8 +4674,8 @@ digraph { let right = LogicalPlanBuilder::scan("right", source, None)?.build()?; let join_keys = ( - vec![datafusion_common::Column::new(Some("left"), "a")], - vec![datafusion_common::Column::new(Some("right"), "a")], + vec![Column::new(Some("left"), "a")], + vec![Column::new(Some("right"), "a")], ); let join = left.join(right, JoinType::Full, join_keys, None)?.build()?; @@ -3879,4 +4696,293 @@ digraph { Ok(()) } + + // --- Tests for aggregate schema mismatch error messages --- + + use crate::catalog::TableProvider; + use datafusion_catalog::Session; + use datafusion_expr::TableType; + + /// A TableProvider 
that returns different schemas for logical planning vs. physical planning. + /// Used to test schema mismatch error messages. + #[derive(Debug)] + struct MockSchemaTableProvider { + logical_schema: SchemaRef, + physical_schema: SchemaRef, + } + + #[async_trait] + impl TableProvider for MockSchemaTableProvider { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.logical_schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec<usize>>, + _filters: &[Expr], + _limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(NoOpExecutionPlan::new(Arc::clone( + &self.physical_schema, + )))) + } + } + + /// Attempts to plan a query with potentially mismatched schemas. + async fn plan_with_schemas( + logical_schema: SchemaRef, + physical_schema: SchemaRef, + query: &str, + ) -> Result<Arc<dyn ExecutionPlan>> { + let provider = MockSchemaTableProvider { + logical_schema, + physical_schema, + }; + let ctx = SessionContext::new(); + ctx.register_table("test", Arc::new(provider)).unwrap(); + + ctx.sql(query).await.unwrap().create_physical_plan().await + } + + #[tokio::test] + // When schemas match, planning proceeds past the schema_satisfied_by check. + // It then panics with the `unimplemented!` error from NoOpExecutionPlan. + #[should_panic(expected = "NoOpExecutionPlan")] + async fn test_aggregate_schema_check_passes() { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + + plan_with_schemas( + Arc::clone(&schema), + schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_metadata() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new( + Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "schema metadata differs"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_count() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "Different number of fields"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_name() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "different_name", + DataType::Int32, + false, + )])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field name at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_type() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int64, false)])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + 
.await + .unwrap_err(); + + assert_contains!(err.to_string(), "field data type at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_nullability() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field nullability at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_metadata() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field metadata at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_multiple() { + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Utf8, false), + ])); + let physical_schema = Arc::new( + Schema::new(vec![ + Field::new("c1", DataType::Int64, true) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + Field::new("c2", DataType::Utf8, false), + ]) + .with_metadata(HashMap::from([( + "schema_key".into(), + "schema_value".into(), + )])), + ); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + // Verify all applicable error fragments are present + let err_str = err.to_string(); + assert_contains!(&err_str, "schema metadata differs"); + assert_contains!(&err_str, "field data type at index"); + assert_contains!(&err_str, "field nullability at index"); + assert_contains!(&err_str, "field metadata at index"); + } + + #[derive(Debug)] + struct MockTableSource { + schema: SchemaRef, + } + + impl TableSource for MockTableSource { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + struct MockTableScanExtensionPlanner; + + #[async_trait] + impl ExtensionPlanner for MockTableScanExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + _node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + _physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } + + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + if scan.source.is::() { + Ok(Some(Arc::new(EmptyExec::new(Arc::clone( + scan.projected_schema.inner(), + ))))) + } else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_table_scan_extension_planner() { + let session_state = make_session_state(); + let planner = Arc::new(MockTableScanExtensionPlanner); + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![planner]); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let table_source = Arc::new(MockTableSource { + schema: Arc::clone(&schema), + }); + let logical_plan = LogicalPlanBuilder::scan("test", table_source, None) + .unwrap() + .build() + .unwrap(); + + let plan = physical_planner + 
.create_physical_plan(&logical_plan, &session_state) + .await + .unwrap(); + + assert_eq!(plan.schema(), schema); + assert!(plan.is::()); + } } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d723620d32323..31d9d7eb471f0 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -29,15 +29,15 @@ pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ - AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, + AvroReadOptions, CsvReadOptions, JsonReadOptions, ParquetReadOptions, }; pub use datafusion_common::Column; pub use datafusion_expr::{ + Expr, expr_fn::*, lit, lit_timestamp_nano, logical_plan::{JoinType, Partitioning}, - Expr, }; pub use datafusion_functions::expr_fn::*; #[cfg(feature = "nested_expressions")] diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 68f83e7f1f115..717182f1d3d5b 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -25,9 +25,9 @@ use std::io::{BufReader, BufWriter}; use std::path::Path; use std::sync::Arc; +use crate::datasource::file_format::FileFormat; use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::CsvSource; use crate::datasource::{MemTable, TableProvider}; @@ -35,28 +35,31 @@ use crate::error::Result; use crate::logical_expr::LogicalPlan; use crate::test_util::{aggr_test_schema, arrow_test_data}; +use datafusion_common::config::CsvOptions; + use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; #[cfg(feature = "compression")] use datafusion_common::DataFusionError; +use datafusion_datasource::TableSchema; use datafusion_datasource::source::DataSourceExec; -#[cfg(feature = "compression")] -use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; +#[cfg(feature = "compression")] +use bzip2::write::BzEncoder; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource_csv::partitioned_csv_config; #[cfg(feature = "compression")] +use flate2::Compression as GzCompression; +#[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] -use flate2::Compression as GzCompression; +use liblzma::write::XzEncoder; use object_store::local_unpartitioned_file; #[cfg(feature = "compression")] -use xz2::write::XzEncoder; -#[cfg(feature = "compression")] use zstd::Encoder as ZstdEncoder; pub fn create_table_dual() -> Arc { @@ -84,17 +87,26 @@ pub fn scan_partitioned_csv( let schema = aggr_test_schema(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv", arrow_test_data()); + let csv_format: Arc = Arc::new(CsvFormat::default()); + let file_groups = partitioned_file_groups( path.as_str(), filename, partitions, - Arc::new(CsvFormat::default()), + &csv_format, FileCompressionType::UNCOMPRESSED, work_dir, )?; - let source = Arc::new(CsvSource::new(true, b'"', b'"')); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(schema); + let source = 
Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); let config = - FileScanConfigBuilder::from(partitioned_csv_config(schema, file_groups, source)) + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) .build(); Ok(DataSourceExec::from_data_source(config)) @@ -105,7 +117,7 @@ pub fn partitioned_file_groups( path: &str, filename: &str, partitions: usize, - file_format: Arc, + file_format: &Arc, file_compression_type: FileCompressionType, work_dir: &Path, ) -> Result> { @@ -189,7 +201,7 @@ pub fn partitioned_file_groups( .collect::>()) } -pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { +pub fn assert_fields_eq(plan: &LogicalPlan, expected: &[&str]) { let actual: Vec = plan .schema() .fields() diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index d31c2719973ec..62c6699f8fcd1 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -20,20 +20,21 @@ use crate::{ execution::{context::SessionState, session_state::SessionStateBuilder}, object_store::{ - memory::InMemory, path::Path, Error, GetOptions, GetResult, ListResult, - MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, - PutPayload, PutResult, + Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + memory::InMemory, path::Path, }, prelude::SessionContext, }; -use futures::{stream::BoxStream, FutureExt}; +use futures::{FutureExt, stream::BoxStream}; +use object_store::{CopyOptions, ObjectStoreExt}; use std::{ fmt::{Debug, Display, Formatter}, sync::Arc, }; use tokio::{ sync::Barrier, - time::{timeout, Duration}, + time::{Duration, timeout}, }; use url::Url; @@ -130,39 +131,40 @@ impl ObjectStore for BlockingObjectStore { location: &Path, options: GetOptions, ) -> object_store::Result { - self.inner.get_opts(location, options).await - } - - async fn head(&self, location: &Path) -> object_store::Result { - println!( - "{} received head call for {location}", - BlockingObjectStore::NAME - ); - // Wait until the expected number of concurrent calls is reached, but timeout after 1 second to avoid hanging failing tests. - let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await; - match wait_result { - Ok(_) => println!( - "{} barrier reached for {location}", + if options.head { + println!( + "{} received head call for {location}", BlockingObjectStore::NAME - ), - Err(_) => { - let error_message = format!( - "{} barrier wait timed out for {location}", + ); + // Wait until the expected number of concurrent calls is reached, but timeout after 1 second to avoid hanging failing tests. + let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await; + match wait_result { + Ok(_) => println!( + "{} barrier reached for {location}", BlockingObjectStore::NAME - ); - log::error!("{error_message}"); - return Err(Error::Generic { - store: BlockingObjectStore::NAME, - source: error_message.into(), - }); + ), + Err(_) => { + let error_message = format!( + "{} barrier wait timed out for {location}", + BlockingObjectStore::NAME + ); + log::error!("{error_message}"); + return Err(Error::Generic { + store: BlockingObjectStore::NAME, + source: error_message.into(), + }); + } } } + // Forward the call to the inner object store. 
- self.inner.head(location).await + self.inner.get_opts(location, options).await } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - self.inner.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.inner.delete_stream(locations) } fn list( @@ -179,15 +181,12 @@ impl ObjectStore for BlockingObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { - self.inner.copy_if_not_exists(from, to).await + self.inner.copy_opts(from, to, options).await } } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 7149c5b0bd8ca..aad659eacbe55 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -23,8 +23,8 @@ pub mod parquet; pub mod csv; use futures::Stream; -use std::any::Any; use std::collections::HashMap; +use std::fmt::Formatter; use std::fs::File; use std::io::Write; use std::path::Path; @@ -36,16 +36,20 @@ use crate::dataframe::DataFrame; use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use crate::datasource::{empty::EmptyTable, provider_as_source}; use crate::error::Result; +use crate::execution::session_state::CacheFactory; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; -use crate::execution::SendableRecordBatchStream; +use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; -use datafusion_common::TableReference; -use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use datafusion_common::{DFSchemaRef, TableReference}; +use datafusion_expr::{ + CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType, + UserDefinedLogicalNodeCore, +}; use std::pin::Pin; use async_trait::async_trait; @@ -203,10 +207,6 @@ impl TestTableProvider {} #[async_trait] impl TableProvider for TestTableProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -282,3 +282,67 @@ impl RecordBatchStream for BoundedStream { self.record_batch.schema() } } + +#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)] +struct CacheNode { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for CacheNode { + fn name(&self) -> &str { + "CacheNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CacheNode") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + Ok(Self { + input: inputs[0].clone(), + }) + } +} + +#[derive(Debug)] +struct TestCacheFactory {} + +impl CacheFactory for TestCacheFactory { + fn create( + &self, + plan: LogicalPlan, + _session_state: &SessionState, + ) -> Result { + Ok(LogicalPlan::Extension(datafusion_expr::Extension { + node: Arc::new(CacheNode { 
input: plan }), + })) + } +} + +/// Create a test table registered to a session context with an associated cache factory +pub async fn test_table_with_cache_factory() -> Result { + let session_state = SessionStateBuilder::new() + .with_cache_factory(Some(Arc::new(TestCacheFactory {}))) + .build(); + let ctx = SessionContext::new_with_state(session_state); + let name = "aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + ctx.table(name).await +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 203d9e97d2a8c..c53495421307b 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -32,17 +32,15 @@ use crate::logical_expr::execution_props::ExecutionProps; use crate::logical_expr::simplify::SimplifyContext; use crate::optimizer::simplify_expressions::ExprSimplifier; use crate::physical_expr::create_physical_expr; +use crate::physical_plan::ExecutionPlan; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::metrics::MetricsSet; -use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; -use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; -use datafusion_datasource::TableSchema; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::path::Path; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; @@ -157,26 +155,20 @@ impl TestParquetFile { maybe_filter: Option, ) -> Result> { let parquet_options = ctx.copied_table_options().parquet; - let source = Arc::new(ParquetSource::new(parquet_options.clone())); - let scan_config_builder = FileScanConfigBuilder::new( - self.object_store_url.clone(), - Arc::clone(&self.schema), - source, - ) - .with_file(PartitionedFile { - object_meta: self.object_meta.clone(), - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }); + let source = Arc::new( + ParquetSource::new(Arc::clone(&self.schema)) + .with_table_parquet_options(parquet_options.clone()), + ); + let scan_config_builder = + FileScanConfigBuilder::new(self.object_store_url.clone(), source) + .with_file(PartitionedFile::new_from_meta(self.object_meta.clone())); let df_schema = Arc::clone(&self.schema).to_dfschema_ref()?; // run coercion on the filters to coerce types etc. 
- let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(Arc::clone(&df_schema)); + let context = SimplifyContext::builder() + .with_schema(Arc::clone(&df_schema)) + .build(); if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, &df_schema).unwrap(); @@ -184,10 +176,10 @@ impl TestParquetFile { create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; let source = Arc::new( - ParquetSource::new(parquet_options) + ParquetSource::new(Arc::clone(&self.schema)) + .with_table_parquet_options(parquet_options) .with_predicate(Arc::clone(&physical_filter_expr)), - ) - .with_schema(TableSchema::from_file_schema(Arc::clone(&self.schema))); + ); let config = scan_config_builder.with_source(source).build(); let parquet_exec = DataSourceExec::from_data_source(config); @@ -204,13 +196,12 @@ impl TestParquetFile { /// Recursively searches for DataSourceExec and returns the metrics /// on the first one it finds pub fn parquet_metrics(plan: &Arc) -> Option { - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if data_source_exec + if let Some(data_source_exec) = plan.downcast_ref::() + && data_source_exec .downcast_to_file_source::() .is_some() - { - return data_source_exec.metrics(); - } + { + return data_source_exec.metrics(); } for child in plan.children() { diff --git a/datafusion/core/tests/catalog/memory.rs b/datafusion/core/tests/catalog/memory.rs index 06ed141b2e8bd..b49183e92e387 100644 --- a/datafusion/core/tests/catalog/memory.rs +++ b/datafusion/core/tests/catalog/memory.rs @@ -26,7 +26,6 @@ use datafusion_catalog::memory::*; use datafusion_catalog::{SchemaProvider, TableProvider}; use datafusion_common::test_util::batches_to_string; use insta::assert_snapshot; -use std::any::Any; use std::sync::Arc; #[test] @@ -83,10 +82,6 @@ fn default_register_schema_not_supported() { #[derive(Debug)] struct TestProvider {} impl CatalogProvider for TestProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn schema_names(&self) -> Vec { unimplemented!() } @@ -116,10 +111,12 @@ async fn test_mem_provider() { assert!(provider.deregister_table(table_name).unwrap().is_none()); let test_table = EmptyTable::new(Arc::new(Schema::empty())); // register table successfully - assert!(provider - .register_table(table_name.to_string(), Arc::new(test_table)) - .unwrap() - .is_none()); + assert!( + provider + .register_table(table_name.to_string(), Arc::new(test_table)) + .unwrap() + .is_none() + ); assert!(provider.table_exist(table_name)); let other_table = EmptyTable::new(Arc::new(Schema::empty())); let result = provider.register_table(table_name.to_string(), Arc::new(other_table)); diff --git a/datafusion/core/tests/catalog_listing/mod.rs b/datafusion/core/tests/catalog_listing/mod.rs new file mode 100644 index 0000000000000..cb6cac4fb0672 --- /dev/null +++ b/datafusion/core/tests/catalog_listing/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod pruned_partition_list; diff --git a/datafusion/core/tests/catalog_listing/pruned_partition_list.rs b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs new file mode 100644 index 0000000000000..8f93dc17dbad2 --- /dev/null +++ b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_schema::DataType; +use futures::{FutureExt, StreamExt as _, TryStreamExt as _}; +use object_store::{ObjectStoreExt, memory::InMemory, path::Path}; + +use datafusion::execution::SessionStateBuilder; +use datafusion_catalog_listing::helpers::{ + describe_partition, list_partitions, pruned_partition_list, +}; +use datafusion_common::ScalarValue; +use datafusion_datasource::ListingTableUrl; +use datafusion_expr::{Expr, col, lit}; +use datafusion_session::Session; + +#[tokio::test] +async fn test_pruned_partition_list_empty() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/mypartition=val1/notparquetfile", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), + ("tablepath/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), + ]); + let filter = Expr::eq(col("mypartition"), lit("val1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter], + ".parquet", + &[(String::from("mypartition"), DataType::Utf8)], + ) + .await + .expect("partition pruning failed") + .collect::>() + .await; + + assert_eq!(pruned.len(), 0); +} + +#[tokio::test] +async fn test_pruned_partition_list() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/mypartition=val1/file.parquet", 100), + ("tablepath/mypartition=val2/file.parquet", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), + ("tablepath/mypartition=val1/other=val3/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), + ]); + let filter = Expr::eq(col("mypartition"), lit("val1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter], + ".parquet", + &[(String::from("mypartition"), DataType::Utf8)], + ) + 
.await + .expect("partition pruning failed") + .try_collect::>() + .await + .unwrap(); + + assert_eq!(pruned.len(), 2); + let f1 = &pruned[0]; + assert_eq!( + f1.object_meta.location.as_ref(), + "tablepath/mypartition=val1/file.parquet" + ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); + let f2 = &pruned[1]; + assert_eq!( + f2.object_meta.location.as_ref(), + "tablepath/mypartition=val1/other=val3/file.parquet" + ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); +} + +#[tokio::test] +async fn test_pruned_partition_list_multi() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), + ]); + let filter1 = Expr::eq(col("part1"), lit("p1v2")); + let filter2 = Expr::eq(col("part2"), lit("p2v1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter1, filter2], + ".parquet", + &[ + (String::from("part1"), DataType::Utf8), + (String::from("part2"), DataType::Utf8), + ], + ) + .await + .expect("partition pruning failed") + .try_collect::>() + .await + .unwrap(); + + assert_eq!(pruned.len(), 2); + let f1 = &pruned[0]; + assert_eq!( + f1.object_meta.location.as_ref(), + "tablepath/part1=p1v2/part2=p2v1/file1.parquet" + ); + assert_eq!( + &f1.partition_values, + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] + ); + let f2 = &pruned[1]; + assert_eq!( + f2.object_meta.location.as_ref(), + "tablepath/part1=p1v2/part2=p2v1/file2.parquet" + ); + assert_eq!( + &f2.partition_values, + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] + ); +} + +#[tokio::test] +async fn test_list_partition() { + let (store, _) = make_test_store_and_state(&[ + ("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), + ]); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 0, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec![]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 1, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 2, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + 
.map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ( + "tablepath/part1=p1v2/part2=p2v1", + 2, + vec!["file1.parquet", "file2.parquet"] + ), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec!["file3.parquet"]), + ] + ); +} + +pub fn make_test_store_and_state( + files: &[(&str, u64)], +) -> (Arc, Arc) { + let memory = InMemory::new(); + + for (name, size) in files { + memory + .put(&Path::from(*name), vec![0; *size as usize].into()) + .now_or_never() + .unwrap() + .unwrap(); + } + + let state = SessionStateBuilder::new().build(); + (Arc::new(memory), Arc::new(state)) +} diff --git a/datafusion/core/tests/config_from_env.rs b/datafusion/core/tests/config_from_env.rs index 976597c8a9ac5..6375d4e25d8eb 100644 --- a/datafusion/core/tests/config_from_env.rs +++ b/datafusion/core/tests/config_from_env.rs @@ -20,35 +20,43 @@ use std::env; #[test] fn from_env() { - // Note: these must be a single test to avoid interference from concurrent execution - let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; - // valid testing in different cases - for bool_option in ["true", "TRUE", "True", "tRUe"] { - env::set_var(env_key, bool_option); - let config = ConfigOptions::from_env().unwrap(); - env::remove_var(env_key); - assert!(config.optimizer.filter_null_join_keys); - } + unsafe { + // Note: these must be a single test to avoid interference from concurrent execution + let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; + // valid testing in different cases + for bool_option in ["true", "TRUE", "True", "tRUe"] { + env::set_var(env_key, bool_option); + let config = ConfigOptions::from_env().unwrap(); + env::remove_var(env_key); + assert!(config.optimizer.filter_null_join_keys); + } - // invalid testing - env::set_var(env_key, "ttruee"); - let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`"); - env::remove_var(env_key); + // invalid testing + env::set_var(env_key, "ttruee"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!( + err, + "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`" + ); + env::remove_var(env_key); - let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; + let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; - // for valid testing - env::set_var(env_key, "4096"); - let config = ConfigOptions::from_env().unwrap(); - assert_eq!(config.execution.batch_size, 4096); + // for valid testing + env::set_var(env_key, "4096"); + let config = ConfigOptions::from_env().unwrap(); + assert_eq!(config.execution.batch_size, 4096); - // for invalid testing - env::set_var(env_key, "abc"); - let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string"); + // for invalid testing + env::set_var(env_key, "abc"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!( + err, + "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string" + ); - env::remove_var(env_key); - let config = ConfigOptions::from_env().unwrap(); - assert_eq!(config.execution.batch_size, 8192); // set to its default value + 
env::remove_var(env_key); + let config = ConfigOptions::from_env().unwrap(); + assert_eq!(config.execution.batch_size, 8192); // set to its default value + } } diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index edcf039e4e704..99783427f022e 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -48,18 +48,21 @@ mod optimizer; /// Run all tests that are found in the `physical_optimizer` directory mod physical_optimizer; -/// Run all tests that are found in the `schema_adapter` directory -mod schema_adapter; - /// Run all tests that are found in the `serde` directory mod serde; /// Run all tests that are found in the `catalog` directory mod catalog; +/// Run all tests that are found in the `catalog_listing` directory +mod catalog_listing; + /// Run all tests that are found in the `tracing` directory mod tracing; +/// Run all tests that are found in the `extension_types` directory +mod extension_types; + #[cfg(test)] #[ctor::ctor] fn init() { diff --git a/datafusion/core/tests/custom_sources_cases/dml_planning.rs b/datafusion/core/tests/custom_sources_cases/dml_planning.rs new file mode 100644 index 0000000000000..24a3df7e0a8fa --- /dev/null +++ b/datafusion/core/tests/custom_sources_cases/dml_planning.rs @@ -0,0 +1,806 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for DELETE, UPDATE, and TRUNCATE planning to verify filter and assignment extraction. + +use std::sync::{Arc, Mutex}; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result; +use datafusion::execution::context::{SessionConfig, SessionContext}; +use datafusion::logical_expr::{ + Expr, LogicalPlan, TableProviderFilterPushDown, TableScan, +}; +use datafusion_catalog::Session; +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::empty::EmptyExec; + +/// A TableProvider that captures the filters passed to delete_from(). 
+struct CaptureDeleteProvider { + schema: SchemaRef, + received_filters: Arc>>>, + filter_pushdown: TableProviderFilterPushDown, + per_filter_pushdown: Option>, +} + +impl CaptureDeleteProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: None, + } + } + + fn new_with_filter_pushdown( + schema: SchemaRef, + filter_pushdown: TableProviderFilterPushDown, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown, + per_filter_pushdown: None, + } + } + + fn new_with_per_filter_pushdown( + schema: SchemaRef, + per_filter_pushdown: Vec, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: Some(per_filter_pushdown), + } + } + + fn captured_filters(&self) -> Option> { + self.received_filters.lock().unwrap().clone() + } +} + +impl std::fmt::Debug for CaptureDeleteProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureDeleteProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureDeleteProvider { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn delete_from( + &self, + _state: &dyn Session, + filters: Vec, + ) -> Result> { + *self.received_filters.lock().unwrap() = Some(filters); + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(per_filter) = &self.per_filter_pushdown + && per_filter.len() == filters.len() + { + return Ok(per_filter.clone()); + } + + Ok(vec![self.filter_pushdown.clone(); filters.len()]) + } +} + +/// A TableProvider that captures filters and assignments passed to update(). 
+#[expect(clippy::type_complexity)] +struct CaptureUpdateProvider { + schema: SchemaRef, + received_filters: Arc>>>, + received_assignments: Arc>>>, + filter_pushdown: TableProviderFilterPushDown, + per_filter_pushdown: Option>, +} + +impl CaptureUpdateProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + received_assignments: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: None, + } + } + + fn new_with_filter_pushdown( + schema: SchemaRef, + filter_pushdown: TableProviderFilterPushDown, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + received_assignments: Arc::new(Mutex::new(None)), + filter_pushdown, + per_filter_pushdown: None, + } + } + + fn captured_filters(&self) -> Option> { + self.received_filters.lock().unwrap().clone() + } + + fn captured_assignments(&self) -> Option> { + self.received_assignments.lock().unwrap().clone() + } +} + +impl std::fmt::Debug for CaptureUpdateProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureUpdateProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureUpdateProvider { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn update( + &self, + _state: &dyn Session, + assignments: Vec<(String, Expr)>, + filters: Vec, + ) -> Result> { + *self.received_filters.lock().unwrap() = Some(filters); + *self.received_assignments.lock().unwrap() = Some(assignments); + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(per_filter) = &self.per_filter_pushdown + && per_filter.len() == filters.len() + { + return Ok(per_filter.clone()); + } + + Ok(vec![self.filter_pushdown.clone(); filters.len()]) + } +} + +/// A TableProvider that captures whether truncate() was called. 
+struct CaptureTruncateProvider { + schema: SchemaRef, + truncate_called: Arc>, +} + +impl CaptureTruncateProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + truncate_called: Arc::new(Mutex::new(false)), + } + } + + fn was_truncated(&self) -> bool { + *self.truncate_called.lock().unwrap() + } +} + +impl std::fmt::Debug for CaptureTruncateProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureTruncateProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureTruncateProvider { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn truncate(&self, _state: &dyn Session) -> Result> { + *self.truncate_called.lock().unwrap() = true; + + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } +} + +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("status", DataType::Utf8, true), + Field::new("value", DataType::Int32, true), + ])) +} + +#[tokio::test] +async fn test_delete_single_filter() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1") + .await? + .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_delete_multiple_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'x'") + .await? + .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty()); + Ok(()) +} + +#[tokio::test] +async fn test_delete_no_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t").await?.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!( + filters.is_empty(), + "DELETE without WHERE should have empty filters" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_complex_expr() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id > 5 AND (status = 'a' OR status = 'b')") + .await? 
+ .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty()); + Ok(()) +} + +#[tokio::test] +async fn test_delete_filter_pushdown_extracts_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx.sql("DELETE FROM t WHERE id = 1").await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. }) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_delete_compound_filters_with_pushdown() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await? + .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + // Should receive both filters, not deduplicate valid separate predicates + assert_eq!( + filters.len(), + 2, + "compound filters should not be over-suppressed" + ); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain status filter" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_mixed_filter_locations() -> Result<()> { + // Test mixed-location filters: some in Filter node, some in TableScan.filters + // This happens when provider uses TableProviderFilterPushDown::Inexact, + // meaning it can push down some predicates but not others. + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Inexact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + // Execute DELETE with compound WHERE clause + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await? 
+ .collect() + .await?; + + // Verify that both predicates are extracted and passed to delete_from(), + // even though they may be split between Filter node and TableScan.filters + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!( + filters.len(), + 2, + "should extract both predicates (union of Filter and TableScan.filters)" + ); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain status filter" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_per_filter_pushdown_mixed_locations() -> Result<()> { + // Force per-filter pushdown decisions to exercise mixed locations in one query. + // First predicate is pushed down (Exact), second stays as residual (Unsupported). + let provider = Arc::new(CaptureDeleteProvider::new_with_per_filter_pushdown( + test_schema(), + vec![ + TableProviderFilterPushDown::Exact, + TableProviderFilterPushDown::Unsupported, + ], + )); + + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx + .sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + // Only the first predicate should be pushed to TableScan.filters. + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. }) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + // Both predicates should still reach the provider (union + dedup behavior). + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 2); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain pushed-down id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain residual status filter" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_update_assignments() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("UPDATE t SET value = 100, status = 'updated' WHERE id = 5") + .await? 
+ .collect() + .await?; + + let assignments = provider + .captured_assignments() + .expect("assignments should be captured"); + assert_eq!(assignments.len(), 2, "should have 2 assignments"); + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty(), "should have filter for WHERE clause"); + Ok(()) +} + +#[tokio::test] +async fn test_update_filter_pushdown_extracts_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx.sql("UPDATE t SET value = 100 WHERE id = 1").await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + // Verify that the optimizer pushed down the filter into TableScan + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. }) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + // Execute the UPDATE and verify filters were extracted and passed to update() + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_update_filter_pushdown_passes_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx + .sql("UPDATE t SET value = 42 WHERE status = 'ready'") + .await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert!( + !scan_filters.is_empty(), + "expected filter pushdown to populate TableScan filters" + ); + + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!( + !filters.is_empty(), + "expected filters extracted from TableScan during UPDATE" + ); + Ok(()) +} + +#[tokio::test] +async fn test_truncate_calls_provider() -> Result<()> { + let provider = Arc::new(CaptureTruncateProvider::new(test_schema())); + let config = SessionConfig::new().set( + "datafusion.optimizer.max_passes", + &ScalarValue::UInt64(Some(0)), + ); + + let ctx = SessionContext::new_with_config(config); + + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("TRUNCATE TABLE t").await?.collect().await?; + + assert!( + provider.was_truncated(), + "truncate() should be called on the TableProvider" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_table_delete() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("DELETE FROM empty_t WHERE id = 1").await; + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_table_update() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("UPDATE empty_t SET value = 1 WHERE id = 1").await; + + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + Ok(()) +} + +#[tokio::test] +async fn test_delete_target_table_scoping() -> Result<()> { + // Test that DELETE only extracts filters from the target table, + // not from other tables (important for DELETE...FROM safety) + let target_provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table( + "target_t", + Arc::clone(&target_provider) as Arc, + )?; + + // For now, we test single-table DELETE + // and validate that the scoping logic is correct + let df = ctx.sql("DELETE FROM target_t WHERE id > 5").await?; + df.collect().await?; + + let filters = target_provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!( + filters[0].to_string().contains("id"), + "Filter should be for id column" + ); + assert!( + filters[0].to_string().contains("5"), + "Filter should contain the value 5" + ); + Ok(()) +} + +#[tokio::test] +async fn test_update_from_drops_non_target_predicates() -> Result<()> { + // UPDATE ... 
+
+#[tokio::test]
+async fn test_unsupported_table_delete() -> Result<()> {
+    let schema = test_schema();
+    let ctx = SessionContext::new();
+
+    let empty_table = datafusion::datasource::empty::EmptyTable::new(schema);
+    ctx.register_table("empty_t", Arc::new(empty_table))?;
+
+    let result = ctx.sql("DELETE FROM empty_t WHERE id = 1").await;
+    assert!(result.is_err() || result.unwrap().collect().await.is_err());
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_unsupported_table_update() -> Result<()> {
+    let schema = test_schema();
+    let ctx = SessionContext::new();
+
+    let empty_table = datafusion::datasource::empty::EmptyTable::new(schema);
+    ctx.register_table("empty_t", Arc::new(empty_table))?;
+
+    let result = ctx.sql("UPDATE empty_t SET value = 1 WHERE id = 1").await;
+
+    assert!(result.is_err() || result.unwrap().collect().await.is_err());
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_delete_target_table_scoping() -> Result<()> {
+    // Test that DELETE only extracts filters from the target table,
+    // not from other tables (important for DELETE ... FROM safety)
+    let target_provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown(
+        test_schema(),
+        TableProviderFilterPushDown::Exact,
+    ));
+    let ctx = SessionContext::new();
+    ctx.register_table(
+        "target_t",
+        Arc::clone(&target_provider) as Arc<dyn TableProvider>,
+    )?;
+
+    // For now, we test single-table DELETE
+    // and validate that the scoping logic is correct
+    let df = ctx.sql("DELETE FROM target_t WHERE id > 5").await?;
+    df.collect().await?;
+
+    let filters = target_provider
+        .captured_filters()
+        .expect("filters should be captured");
+    assert_eq!(filters.len(), 1);
+    assert!(
+        filters[0].to_string().contains("id"),
+        "Filter should be for id column"
+    );
+    assert!(
+        filters[0].to_string().contains("5"),
+        "Filter should contain the value 5"
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_update_from_drops_non_target_predicates() -> Result<()> {
+    // UPDATE ... FROM is not currently supported
+    // TODO fix https://github.com/apache/datafusion/issues/19950
+    let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown(
+        test_schema(),
+        TableProviderFilterPushDown::Exact,
+    ));
+    let ctx = SessionContext::new();
+    ctx.register_table("t1", Arc::clone(&target_provider) as Arc<dyn TableProvider>)?;
+
+    let source_schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("status", DataType::Utf8, true),
+        // t2-only column to avoid false negatives after qualifier stripping
+        Field::new("src_only", DataType::Utf8, true),
+    ]));
+    let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema);
+    ctx.register_table("t2", Arc::new(source_table))?;
+
+    let result = ctx
+        .sql(
+            "UPDATE t1 SET value = 1 FROM t2 \
+             WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10",
+        )
+        .await;
+
+    // Verify UPDATE ... FROM is rejected with an appropriate error
+    // TODO fix https://github.com/apache/datafusion/issues/19950
+    assert!(result.is_err());
+    let err = result.unwrap_err();
+    assert!(
+        err.to_string().contains("UPDATE ... FROM is not supported"),
+        "Expected 'UPDATE ... FROM is not supported' error, got: {err}"
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_delete_qualifier_stripping_and_validation() -> Result<()> {
+    // Test that filter qualifiers are properly stripped and validated
+    // Unqualified predicates should work fine
+    let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown(
+        test_schema(),
+        TableProviderFilterPushDown::Exact,
+    ));
+    let ctx = SessionContext::new();
+    ctx.register_table("t", Arc::clone(&provider) as Arc<dyn TableProvider>)?;
+
+    // Execute DELETE with unqualified column reference
+    // (After parsing, the planner adds qualifiers, but our validation should accept them)
+    let df = ctx.sql("DELETE FROM t WHERE id = 1").await?;
+    df.collect().await?;
+
+    let filters = provider
+        .captured_filters()
+        .expect("filters should be captured");
+    assert!(!filters.is_empty(), "Should have extracted filter");
+
+    // Verify qualifiers are stripped: check that Column expressions have no qualifier
+    let has_qualified_column = filters[0]
+        .exists(|expr| Ok(matches!(expr, Expr::Column(col) if col.relation.is_some())))?;
+    assert!(
+        !has_qualified_column,
+        "Filter should have unqualified columns after stripping"
+    );
+
+    // Also verify the string representation doesn't contain table qualifiers
+    let filter_str = filters[0].to_string();
+    assert!(
+        !filter_str.contains("t.id"),
+        "Filter should not contain qualified column reference, got: {filter_str}"
+    );
+    assert!(
+        filter_str.contains("id") || filter_str.contains("1"),
+        "Filter should reference id column or the value 1, got: {filter_str}"
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_unsupported_table_truncate() -> Result<()> {
+    let schema = test_schema();
+    let ctx = SessionContext::new();
+
+    let empty_table = datafusion::datasource::empty::EmptyTable::new(schema);
+    ctx.register_table("empty_t", Arc::new(empty_table))?;
+
+    let result = ctx.sql("TRUNCATE TABLE empty_t").await;
+
+    assert!(result.is_err() || result.unwrap().collect().await.is_err());
+
+    Ok(())
+}
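+
+// Illustrative sketch (assumption, not code from this PR): the qualifier
+// stripping exercised by `test_delete_qualifier_stripping_and_validation`
+// can be written with the tree-node API roughly like this, rewriting every
+// qualified column into an unqualified one before the filters reach the
+// provider:
+//
+// use datafusion_common::tree_node::{Transformed, TreeNode};
+//
+// fn strip_qualifiers(expr: Expr) -> Result<Expr> {
+//     Ok(expr
+//         .transform(|e| {
+//             Ok(match e {
+//                 Expr::Column(col) => Transformed::yes(Expr::Column(
+//                     Column::new_unqualified(col.name),
+//                 )),
+//                 other => Transformed::no(other),
+//             })
+//         })?
+//         .data)
+// }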
diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs
index cbdc4a448ea41..cef75b444f6fe 100644
--- a/datafusion/core/tests/custom_sources_cases/mod.rs
+++ b/datafusion/core/tests/custom_sources_cases/mod.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -28,25 +27,27 @@ use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result;
 use datafusion::execution::context::{SessionContext, TaskContext};
 use datafusion::logical_expr::{
-    col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE,
+    Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, col,
 };
 use datafusion::physical_plan::{
-    collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
-    RecordBatchStream, SendableRecordBatchStream, Statistics,
+    ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
+    RecordBatchStream, SendableRecordBatchStream, Statistics, collect,
 };
 use datafusion::scalar::ScalarValue;
 use datafusion_catalog::Session;
 use datafusion_common::cast::as_primitive_array;
 use datafusion_common::project_schema;
 use datafusion_common::stats::Precision;
+use datafusion_common::tree_node::TreeNodeRecursion;
 use datafusion_physical_expr::EquivalenceProperties;
+use datafusion_physical_plan::PlanProperties;
 use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
-use datafusion_physical_plan::PlanProperties;
 
 use async_trait::async_trait;
 use futures::stream::Stream;
 
+mod dml_planning;
 mod provider_filter_pushdown;
 mod statistics;
 
@@ -78,7 +79,7 @@ struct CustomTableProvider;
 #[derive(Debug, Clone)]
 struct CustomExecutionPlan {
     projection: Option<Vec<usize>>,
-    cache: PlanProperties,
+    cache: Arc<PlanProperties>,
 }
 
 impl CustomExecutionPlan {
@@ -87,7 +88,10 @@ impl CustomExecutionPlan {
         let schema =
             project_schema(&schema, projection.as_ref()).expect("projected schema");
         let cache = Self::compute_properties(schema);
-        Self { projection, cache }
+        Self {
+            projection,
+            cache: Arc::new(cache),
+        }
     }
 
     /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
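+    // Sketch (assumption, for illustration only): with the cache held as
+    // `Arc<PlanProperties>`, callers can share the computed properties
+    // cheaply instead of recomputing them, e.g.
+    //     let props: Arc<PlanProperties> = Arc::clone(plan.properties());
+    // which is what the `properties(&self) -> &Arc<PlanProperties>` change
+    // below enables.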
@@ -152,11 +156,7 @@ impl ExecutionPlan for CustomExecutionPlan {
         Self::static_name()
     }
 
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.cache
     }
 
@@ -179,16 +179,12 @@ impl ExecutionPlan for CustomExecutionPlan {
         Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 }))
     }
 
-    fn statistics(&self) -> Result<Statistics> {
-        self.partition_statistics(None)
-    }
-
-    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
         if partition.is_some() {
-            return Ok(Statistics::new_unknown(&self.schema()));
+            return Ok(Arc::new(Statistics::new_unknown(&self.schema())));
         }
         let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap();
-        Ok(Statistics {
+        Ok(Arc::new(Statistics {
             num_rows: Precision::Exact(batch.num_rows()),
             total_byte_size: Precision::Absent,
             column_statistics: self
@@ -207,16 +203,28 @@ impl ExecutionPlan for CustomExecutionPlan {
                     ..Default::default()
                 })
                 .collect(),
-        })
+        }))
+    }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
     }
 }
 
 #[async_trait]
 impl TableProvider for CustomTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         TEST_CUSTOM_SCHEMA_REF!()
     }
@@ -316,8 +324,9 @@ async fn optimizers_catch_all_statistics() {
     assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}"));
 }
 
+#[expect(clippy::needless_pass_by_value)]
 fn contains_place_holder_exec(plan: Arc<dyn ExecutionPlan>) -> bool {
-    if plan.as_any().is::<PlaceholderRowExec>() {
+    if plan.is::<PlaceholderRowExec>() {
         true
     } else if plan.children().len() != 1 {
         false
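+
+// Review sketch (hypothetical helper, not part of this change): the
+// `apply_expressions` body added above is repeated verbatim in
+// provider_filter_pushdown.rs and statistics.rs below; it could be shared as:
+//
+// fn apply_ordering_expressions(
+//     props: &PlanProperties,
+//     f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+// ) -> Result<TreeNodeRecursion> {
+//     let mut tnr = TreeNodeRecursion::Continue;
+//     if let Some(ordering) = props.output_ordering() {
+//         for sort_expr in ordering {
+//             tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+//         }
+//     }
+//     Ok(tnr)
+// }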
diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
index c80c0b4bf54ba..e52c559ec79ef 100644
--- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
@@ -29,13 +29,14 @@ use datafusion::logical_expr::TableProviderFilterPushDown;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
-    SendableRecordBatchStream, Statistics,
+    SendableRecordBatchStream,
 };
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
 use datafusion_catalog::Session;
 use datafusion_common::cast::as_primitive_array;
-use datafusion_common::{internal_err, not_impl_err};
+use datafusion_common::tree_node::TreeNodeRecursion;
+use datafusion_common::{DataFusionError, internal_err, not_impl_err};
 use datafusion_expr::expr::{BinaryExpr, Cast};
 use datafusion_functions_aggregate::expr_fn::count;
 use datafusion_physical_expr::EquivalenceProperties;
@@ -62,13 +63,16 @@ fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
 
 #[derive(Debug)]
 struct CustomPlan {
     batches: Vec<RecordBatch>,
-    cache: PlanProperties,
+    cache: Arc<PlanProperties>,
 }
 
 impl CustomPlan {
     fn new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
         let cache = Self::compute_properties(schema);
-        Self { batches, cache }
+        Self {
+            batches,
+            cache: Arc::new(cache),
+        }
     }
 
     /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
@@ -105,11 +109,7 @@ impl ExecutionPlan for CustomPlan {
         Self::static_name()
     }
 
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.cache
     }
 
@@ -134,16 +134,36 @@ impl ExecutionPlan for CustomPlan {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let schema_captured = self.schema().clone();
         Ok(Box::pin(RecordBatchStreamAdapter::new(
             self.schema(),
-            futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
+            futures::stream::iter(self.batches.clone().into_iter().map(move |batch| {
+                // Re-project each stored batch onto the (possibly projected)
+                // output schema by matching field names
+                let projection: Vec<usize> = schema_captured
+                    .fields()
+                    .iter()
+                    .filter_map(|field| batch.schema().index_of(field.name()).ok())
+                    .collect();
+                batch
+                    .project(&projection)
+                    .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+            })),
         )))
     }
 
-    fn statistics(&self) -> Result<Statistics> {
-        // here we could provide more accurate statistics
-        // but we want to test the filter pushdown not the CBOs
-        Ok(Statistics::new_unknown(&self.schema()))
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
     }
 }
 
@@ -155,10 +175,6 @@ struct CustomProvider {
 
 #[async_trait]
 impl TableProvider for CustomProvider {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         self.zero_batch.schema()
     }
@@ -183,7 +199,7 @@ impl TableProvider for CustomProvider {
                     Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
                     Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
                     Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
-                    Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
+                    Expr::Cast(Cast { expr, field: _ }) => match expr.deref() {
                         Expr::Literal(lit_value, _) => match lit_value {
                             ScalarValue::Int8(Some(v)) => *v as i64,
                             ScalarValue::Int16(Some(v)) => *v as i64,
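+
+// Sketch (illustrative only; `project_to_schema` is a hypothetical name): the
+// name-based re-projection added to `execute` above, written as a standalone
+// helper. It maps each field of the (possibly projected) output schema back
+// to its index in the stored batch and projects accordingly:
+//
+// fn project_to_schema(batch: &RecordBatch, schema: &SchemaRef) -> Result<RecordBatch> {
+//     let indices: Vec<usize> = schema
+//         .fields()
+//         .iter()
+//         .filter_map(|f| batch.schema().index_of(f.name()).ok())
+//         .collect();
+//     batch
+//         .project(&indices)
+//         .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
+// }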
diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs
index 403c04f1737e1..01c4deac5ccd3 100644
--- a/datafusion/core/tests/custom_sources_cases/statistics.rs
+++ b/datafusion/core/tests/custom_sources_cases/statistics.rs
@@ -17,7 +17,7 @@
 //! This module contains end to end tests of statistics propagation
 
-use std::{any::Any, sync::Arc};
+use std::sync::Arc;
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion::execution::context::TaskContext;
@@ -33,6 +33,7 @@ use datafusion::{
     scalar::ScalarValue,
 };
 use datafusion_catalog::Session;
+use datafusion_common::tree_node::TreeNodeRecursion;
 use datafusion_common::{project_schema, stats::Precision};
 use datafusion_physical_expr::EquivalenceProperties;
 use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
@@ -45,7 +46,7 @@ use async_trait::async_trait;
 struct StatisticsValidation {
     stats: Statistics,
     schema: Arc<Schema>,
-    cache: PlanProperties,
+    cache: Arc<PlanProperties>,
 }
 
 impl StatisticsValidation {
@@ -59,7 +60,7 @@ impl StatisticsValidation {
         Self {
             stats,
             schema,
-            cache,
+            cache: Arc::new(cache),
         }
     }
 
@@ -76,10 +77,6 @@ impl StatisticsValidation {
 
 #[async_trait]
 impl TableProvider for StatisticsValidation {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
@@ -154,11 +151,7 @@ impl ExecutionPlan for StatisticsValidation {
         Self::static_name()
     }
 
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.cache
     }
 
@@ -181,17 +174,29 @@ impl ExecutionPlan for StatisticsValidation {
         unimplemented!("This plan only serves for testing statistics")
     }
 
-    fn statistics(&self) -> Result<Statistics> {
-        Ok(self.stats.clone())
-    }
-
-    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
         if partition.is_some() {
-            Ok(Statistics::new_unknown(&self.schema))
+            Ok(Arc::new(Statistics::new_unknown(&self.schema)))
         } else {
-            Ok(self.stats.clone())
+            Ok(Arc::new(self.stats.clone()))
         }
     }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
+    }
 }
 
 fn init_ctx(stats: Statistics, schema: Schema) -> Result<SessionContext> {
@@ -214,6 +219,7 @@ fn fully_defined() -> (Statistics, Schema) {
                 min_value: Precision::Exact(ScalarValue::Int32(Some(-24))),
                 sum_value: Precision::Exact(ScalarValue::Int64(Some(10))),
                 null_count: Precision::Exact(0),
+                byte_size: Precision::Absent,
             },
             ColumnStatistics {
                 distinct_count: Precision::Exact(13),
@@ -221,6 +227,7 @@ fn fully_defined() -> (Statistics, Schema) {
                 min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))),
                 sum_value: Precision::Exact(ScalarValue::Int64(Some(10))),
                 null_count: Precision::Exact(5),
+                byte_size: Precision::Absent,
             },
         ],
     },
@@ -240,7 +247,7 @@ async fn sql_basic() -> Result<()> {
     let physical_plan = df.create_physical_plan().await.unwrap();
 
     // the statistics should be those of the source
-    assert_eq!(stats, physical_plan.partition_statistics(None)?);
+    assert_eq!(stats, *physical_plan.partition_statistics(None)?);
 
     Ok(())
 }
@@ -257,7 +264,7 @@ async fn sql_filter() -> Result<()> {
     let physical_plan = df.create_physical_plan().await.unwrap();
 
     let stats = physical_plan.partition_statistics(None)?;
-    assert_eq!(stats.num_rows, Precision::Inexact(1));
+    assert_eq!(stats.num_rows, Precision::Inexact(7));
 
     Ok(())
 }
@@ -270,17 +277,18 @@ async fn sql_limit() -> Result<()> {
     let df = ctx.sql("SELECT * FROM
stats_table LIMIT 5").await.unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is smaller than the original number of lines we mark the statistics as inexact + // and cap NDV at the new row count + let limit_stats = physical_plan.partition_statistics(None)?; + assert_eq!(limit_stats.num_rows, Precision::Exact(5)); + // c1: NDV=2 stays at 2 (already below limit of 5) assert_eq!( - Statistics { - num_rows: Precision::Exact(5), - column_statistics: stats - .column_statistics - .iter() - .map(|c| c.clone().to_inexact()) - .collect(), - total_byte_size: Precision::Absent - }, - physical_plan.partition_statistics(None)? + limit_stats.column_statistics[0].distinct_count, + Precision::Inexact(2) + ); + // c2: NDV=13 capped to 5 (the limit row count) + assert_eq!( + limit_stats.column_statistics[1].distinct_count, + Precision::Inexact(5) ); let df = ctx @@ -289,7 +297,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.partition_statistics(None)?); + assert_eq!(stats, *physical_plan.partition_statistics(None)?); Ok(()) } @@ -309,7 +317,7 @@ async fn sql_window() -> Result<()> { let result = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, result.num_rows); - let col_stats = result.column_statistics; + let col_stats = &result.column_statistics; assert_eq!(2, col_stats.len()); assert_eq!(stats.column_statistics[1], col_stats[0]); diff --git a/datafusion/core/tests/data/json_array.json b/datafusion/core/tests/data/json_array.json new file mode 100644 index 0000000000000..1a8716dbf4beb --- /dev/null +++ b/datafusion/core/tests/data/json_array.json @@ -0,0 +1,5 @@ +[ + {"a": 1, "b": "hello"}, + {"a": 2, "b": "world"}, + {"a": 3, "b": "test"} +] diff --git a/datafusion/core/tests/data/json_empty_array.json b/datafusion/core/tests/data/json_empty_array.json new file mode 100644 index 0000000000000..fe51488c7066f --- /dev/null +++ b/datafusion/core/tests/data/json_empty_array.json @@ -0,0 +1 @@ +[] diff --git a/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow new file mode 100644 index 0000000000000..bad9e3de4a57f Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow differ diff --git a/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow new file mode 100644 index 0000000000000..4a07fbfa47f32 Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow differ diff --git a/datafusion/core/tests/data/recursive_cte/closure.csv b/datafusion/core/tests/data/recursive_cte/closure.csv new file mode 100644 index 0000000000000..a31e2bfbf36b6 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/closure.csv @@ -0,0 +1,6 @@ +start,end +1,2 +2,3 +2,4 +2,4 +4,1 \ No newline at end of file diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 265862ff9af8a..2ada0411f4f8c 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the 
License. -use arrow::array::{types::Int32Type, ListArray}; +use arrow::array::{ListArray, types::Int32Type}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Schema}; use arrow::{ @@ -31,7 +31,7 @@ use datafusion::prelude::*; use datafusion_common::test_util::batches_to_string; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; -use datafusion_expr::{table_scan, ExprSchemable, LogicalPlanBuilder}; +use datafusion_expr::{ExprSchemable, LogicalPlanBuilder, table_scan}; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; use datafusion_functions_nested::map::map; use insta::assert_snapshot; @@ -313,10 +313,10 @@ async fn test_fn_arrow_typeof() -> Result<()> { +----------------------+ | arrow_typeof(test.l) | +----------------------+ - | List(nullable Int32) | - | List(nullable Int32) | - | List(nullable Int32) | - | List(nullable Int32) | + | List(Int32) | + | List(Int32) | + | List(Int32) | + | List(Int32) | +----------------------+ "); @@ -402,7 +402,7 @@ async fn test_fn_approx_median() -> Result<()> { +-----------------------+ | approx_median(test.b) | +-----------------------+ - | 10 | + | 10.0 | +-----------------------+ "); @@ -422,7 +422,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { +---------------------------------------------------------------------------+ | approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] | +---------------------------------------------------------------------------+ - | 10 | + | 10.0 | +---------------------------------------------------------------------------+ "); @@ -437,7 +437,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { +----------------------------------------------------------------------------+ | approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] | +----------------------------------------------------------------------------+ - | 100 | + | 100.0 | +----------------------------------------------------------------------------+ "); @@ -457,7 +457,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { +--------------------------------------------------------------------+ | approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] | +--------------------------------------------------------------------+ - | 10 | + | 10.0 | +--------------------------------------------------------------------+ " ); @@ -477,7 +477,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { +---------------------------------------------------------------------+ | approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] | +---------------------------------------------------------------------+ - | 100 | + | 100.0 | +---------------------------------------------------------------------+ " ); @@ -494,7 +494,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { +------------------------------------------------------------------------------------+ | approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] | +------------------------------------------------------------------------------------+ - | 30 | + | 30.25 | +------------------------------------------------------------------------------------+ "); @@ -510,7 +510,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { +-------------------------------------------------------------------------------------+ | approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] | 
+-------------------------------------------------------------------------------------+ - | 69 | + | 69.85 | +-------------------------------------------------------------------------------------+ "); diff --git a/datafusion/core/tests/dataframe/describe.rs b/datafusion/core/tests/dataframe/describe.rs index 9bd69dfa72b4c..c61fe4fed1615 100644 --- a/datafusion/core/tests/dataframe/describe.rs +++ b/datafusion/core/tests/dataframe/describe.rs @@ -17,7 +17,7 @@ use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_common::test_util::batches_to_string; -use datafusion_common::{test_util::parquet_test_data, Result}; +use datafusion_common::{Result, test_util::parquet_test_data}; use insta::assert_snapshot; #[tokio::test] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 05f5a204c0963..e0830754399db 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -20,10 +20,10 @@ mod dataframe_functions; mod describe; use arrow::array::{ - record_batch, Array, ArrayRef, BooleanArray, DictionaryArray, FixedSizeListArray, - FixedSizeListBuilder, Float32Array, Float64Array, Int32Array, Int32Builder, - Int8Array, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, - StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, + Array, ArrayRef, BooleanArray, DictionaryArray, FixedSizeListArray, + FixedSizeListBuilder, Float32Array, Float64Array, Int8Array, Int32Array, + Int32Builder, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, + StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, record_batch, }; use arrow::buffer::ScalarBuffer; use arrow::datatypes::{ @@ -43,6 +43,7 @@ use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::expr_fn::{first_value, lead, row_number}; use insta::assert_snapshot; use object_store::local::LocalFileSystem; +use rstest::rstest; use std::collections::HashMap; use std::fs; use std::path::Path; @@ -56,18 +57,16 @@ use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ColumnarValue, Volatility}; -use datafusion::prelude::{ - CsvReadOptions, JoinType, NdJsonReadOptions, ParquetReadOptions, -}; +use datafusion::prelude::{CsvReadOptions, JoinType, ParquetReadOptions}; use datafusion::test_util::{ parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table, - test_table_with_name, + test_table_with_cache_factory, test_table_with_name, }; use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ - assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema, - DataFusionError, ScalarValue, TableReference, UnnestOptions, + Constraint, Constraints, DFSchema, DataFusionError, ScalarValue, SchemaError, + TableReference, UnnestOptions, assert_contains, internal_datafusion_err, }; use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; @@ -76,23 +75,24 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprFunctionExt, 
ExprSchemable, LogicalPlan, - LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, TableType, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarFunctionImplementation, SortExpr, TableType, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, cast, col, create_udf, exists, + in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, }; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; use datafusion::error::Result as DataFusionResult; +use datafusion::execution::options::JsonReadOptions; use datafusion_functions_window::expr_fn::lag; // Get string representation of the plan @@ -305,6 +305,27 @@ async fn select_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn select_columns_with_nonexistent_columns() -> Result<()> { + let t = test_table().await?; + let t2 = t.select_columns(&["canada", "c2", "rocks"]); + + match t2 { + Err(DataFusionError::SchemaError(boxed_err, _)) => { + // Verify it's the first invalid column + match boxed_err.as_ref() { + SchemaError::FieldNotFound { field, .. } => { + assert_eq!(field.name(), "canada"); + } + _ => panic!("Expected SchemaError::FieldNotFound for 'canada'"), + } + } + _ => panic!("Expected SchemaError"), + } + + Ok(()) +} + #[tokio::test] async fn select_expr() -> Result<()> { // build plan using Table API @@ -392,14 +413,14 @@ async fn select_with_periods() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+ | f.c1 | +------+ | 1 | | 10 | +------+ - "### + " ); Ok(()) @@ -513,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> { async fn drop_columns_with_empty_array() -> Result<()> { // build plan using Table API let t = test_table().await?; - let t2 = t.drop_columns(&[])?; + let drop_columns = vec![] as Vec<&str>; + let t2 = t.drop_columns(&drop_columns)?; let plan = t2.logical_plan().clone(); // build query using SQL @@ -528,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> { Ok(()) } +#[tokio::test] +async fn drop_columns_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2, + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&["another_table.c2", "another_table.c11"])?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, 
"aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn drop_columns_qualified_find_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2.clone(), + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&t2.find_qualified_columns(&["c2", "c11"])?)?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn test_find_qualified_names() -> Result<()> { + let t = test_table().await?; + let column_names = ["c1", "c2", "c3"]; + let columns = t.find_qualified_columns(&column_names)?; + + // Expected results for each column + let binding = TableReference::bare("aggregate_test_100"); + let expected = [ + (Some(&binding), "c1"), + (Some(&binding), "c2"), + (Some(&binding), "c3"), + ]; + + // Verify we got the expected number of results + assert_eq!( + columns.len(), + expected.len(), + "Expected {} columns, got {}", + expected.len(), + columns.len() + ); + + // Iterate over the results and check each one individually + for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() { + let (actual_table_ref, actual_field_ref) = actual; + let (expected_table_ref, expected_field_name) = expected; + + // Check table reference + assert_eq!( + actual_table_ref, expected_table_ref, + "Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}" + ); + + // Check field name + assert_eq!( + actual_field_ref.name(), + *expected_field_name, + "Column {i}: expected field name '{expected_field_name}', got '{actual_field_ref}'" + ); + } + + Ok(()) +} + #[tokio::test] async fn drop_with_quotes() -> Result<()> { // define data with a column name that has a "." 
in it: @@ -547,14 +670,14 @@ async fn drop_with_quotes() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r#" +------+ | f"c2 | +------+ | 11 | | 2 | +------+ - "### + "# ); Ok(()) @@ -573,20 +696,20 @@ async fn drop_with_periods() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; + let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?; let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+ | f.c2 | +------+ | 11 | | 2 | +------+ - "### + " ); Ok(()) @@ -723,23 +846,23 @@ async fn test_aggregate_with_pk() -> Result<()> { assert_snapshot!( physical_plan_to_string(&df).await, - @r###" + @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[] DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -766,9 +889,8 @@ async fn test_aggregate_with_pk2() -> Result<()> { physical_plan_to_string(&df).await, @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=Sorted - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 AND name@1 = a - DataSourceExec: partitions=1, partition_sizes=[1] + FilterExec: id@0 = 1 AND name@1 = a + DataSourceExec: partitions=1, partition_sizes=[1] " ); @@ -778,13 +900,13 @@ async fn test_aggregate_with_pk2() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -815,9 +937,8 @@ async fn test_aggregate_with_pk3() -> Result<()> { physical_plan_to_string(&df).await, @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=PartiallySorted([0]) - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - DataSourceExec: partitions=1, partition_sizes=[1] + FilterExec: id@0 = 1 + DataSourceExec: partitions=1, partition_sizes=[1] " ); @@ -827,13 +948,13 @@ async fn test_aggregate_with_pk3() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -866,9 +987,8 @@ async fn test_aggregate_with_pk4() -> Result<()> { physical_plan_to_string(&df).await, @r" AggregateExec: mode=Single, gby=[id@0 as id], aggr=[], ordering_mode=Sorted - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - DataSourceExec: partitions=1, partition_sizes=[1] + FilterExec: id@0 = 1 + DataSourceExec: partitions=1, partition_sizes=[1] " ); @@ -876,13 +996,13 @@ async fn test_aggregate_with_pk4() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | id | +----+ | 1 | +----+ - "### + " ); Ok(()) @@ -904,7 +1024,7 @@ async fn test_aggregate_alias() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c2 | +----+ @@ -914,7 +1034,7 @@ async fn test_aggregate_alias() -> Result<()> { | 5 | | 6 | +----+ - "### + " ); Ok(()) @@ -951,7 +1071,7 @@ async fn test_aggregate_with_union() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------------+ | c1 | sum_result | +----+------------+ @@ -961,7 +1081,7 @@ async fn test_aggregate_with_union() -> 
Result<()> { | d | 126 | | e | 121 | +----+------------+ - "### + " ); Ok(()) } @@ -987,7 +1107,7 @@ async fn test_aggregate_subexpr() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----------------+------+ | c2 + Int32(10) | sum | +----------------+------+ @@ -997,7 +1117,7 @@ async fn test_aggregate_subexpr() -> Result<()> { | 15 | 95 | | 16 | -146 | +----------------+------+ - "### + " ); Ok(()) @@ -1020,7 +1140,7 @@ async fn test_aggregate_name_collision() -> Result<()> { // The select expr has the same display_name as the group_expr, // but since they are different expressions, it should fail. .expect_err("Expected error"); - assert_snapshot!(df.strip_backtrace(), @r###"Schema error: No field named aggregate_test_100.c2. Valid fields are "aggregate_test_100.c2 + aggregate_test_100.c3"."###); + assert_snapshot!(df.strip_backtrace(), @r#"Schema error: No field named aggregate_test_100.c2. Valid fields are "aggregate_test_100.c2 + aggregate_test_100.c3"."#); Ok(()) } @@ -1079,33 +1199,33 @@ async fn window_using_aggregates() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df), - @r###" + @r" +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | | | | | | | | 1 | -85 | - | -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 | - | -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 | - | -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 | - | -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 | - | -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 | - | -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 | - | -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 | - | -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 | - | -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 | - | -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 | - | -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 | - | -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 | - | -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 | - | -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 | - | -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 | - | -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 | - | -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 | - | -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 | - | -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 | - | -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 | + | -85 | -101 | 14 | -12.0 | -12 | 83 | -101 | 4 | -54 | + | -85 | -101 | 17 | -25.0 | -25 | 83 | -101 | 5 | -31 | + | -85 | -12 | 10 | -32.75 | -34 | 83 | -85 | 3 | 13 | + | -85 | -25 | 3 | -56.0 | -56 | -25 | -85 | 1 | -5 | + | -85 | -31 | 18 | -29.75 | -28 | 83 | -101 | 5 | 36 | + | -85 | -38 | 16 | -25.0 | -25 | 83 | -101 | 4 | 65 | + | -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 | + | -85 | -48 | 6 | -35.75 | -36 | 83 | -85 | 2 | -43 | + | -85 | -5 | 4 | -37.75 | -40 | -5 | -85 | 1 | 83 | + | -85 | -54 | 15 | -17.0 | -18 | 83 | -101 | 4 | -38 | + | -85 | -56 | 2 | -70.5 | -70 | -56 | -85 | 1 | -25 | + | -85 | -72 | 9 | -43.0 | -43 | 83 | -85 | 3 | -12 | + | -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 | + | -85 | 13 | 11 | -17.0 | -18 | 83 | -85 | 3 | 14 | + | -85 | 13 | 11 | -25.0 | -25 | 83 | -85 | 3 | 13 | + | -85 | 14 | 12 | -12.0 | -12 | 83 | -85 | 3 | 17 | + | -85 | 17 | 13 | -11.25 | -8 | 83 | -85 | 4 | -101 | + | -85 | 45 | 8 | -34.5 | -34 | 83 | -85 | 3 | -72 | + | -85 | 65 | 17 | 
-17.0 | -18 | 83 | -101 | 5 | -101 | + | -85 | 83 | 5 | -25.0 | -25 | 83 | -85 | 2 | -48 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ - "### + " ); Ok(()) @@ -1172,7 +1292,7 @@ async fn window_aggregates_with_filter() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---------+---------+---------+---------+---------+----+-----+ | sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val | +---------+---------+---------+---------+---------+----+-----+ @@ -1182,7 +1302,7 @@ async fn window_aggregates_with_filter() -> Result<()> { | 5 | 2.5 | 1 | 4 | 2 | 4 | 4 | | 5 | 2.5 | 1 | 4 | 2 | 5 | -1 | +---------+---------+---------+---------+---------+----+-----+ - "### + " ); Ok(()) @@ -1238,7 +1358,7 @@ async fn test_distinct_sort_by() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1248,7 +1368,7 @@ async fn test_distinct_sort_by() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1286,7 +1406,7 @@ async fn test_distinct_on() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1296,7 +1416,7 @@ async fn test_distinct_on() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1321,7 +1441,7 @@ async fn test_distinct_on_sort_by() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1331,7 +1451,7 @@ async fn test_distinct_on_sort_by() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1395,13 +1515,13 @@ async fn join_coercion_unnamed() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 10 | d | +----+------+ - "### + " ); Ok(()) } @@ -1420,13 +1540,13 @@ async fn join_on() -> Result<()> { [col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))], )?; - assert_snapshot!(join.logical_plan(), @r###" + assert_snapshot!(join.logical_plan(), @r" Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2 Projection: a.c1, a.c2 TableScan: a Projection: b.c1, b.c2 TableScan: b - "###); + "); Ok(()) } @@ -1449,7 +1569,11 @@ async fn join_on_filter_datatype() -> Result<()> { let err = join.into_optimized_plan().unwrap_err(); assert_snapshot!( err.strip_backtrace(), - @"type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8" + @r" + type_coercion + caused by + Error during planning: Join condition must be boolean type, but got Utf8 + " ); Ok(()) } @@ -1627,7 +1751,9 @@ async fn register_table() -> Result<()> { let df_impl = DataFrame::new(ctx.state(), df.logical_plan().clone()); // register a dataframe as a table - ctx.register_table("test_table", df_impl.clone().into_view())?; + let table_provider = df_impl.clone().into_view(); + assert_eq!(table_provider.table_type(), TableType::View); + ctx.register_table("test_table", table_provider)?; // pull the table out let table = ctx.table("test_table").await?; @@ -1644,7 +1770,7 @@ async fn register_table() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+-----------------------------+ | c1 | sum(aggregate_test_100.c12) | +----+-----------------------------+ @@ -1654,13 +1780,13 @@ async fn register_table() -> Result<()> { | d | 8.793968289758968 | | e | 10.206140546981722 | +----+-----------------------------+ - "### + " ); // the results are the same as the results from the view, modulo the leaf table name assert_snapshot!( 
batches_to_sort_string(table_results), - @r###" + @r" +----+---------------------+ | c1 | sum(test_table.c12) | +----+---------------------+ @@ -1670,7 +1796,7 @@ async fn register_table() -> Result<()> { | d | 8.793968289758968 | | e | 10.206140546981722 | +----+---------------------+ - "### + " ); Ok(()) } @@ -1719,7 +1845,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+-----+ | c1 | c2 | c3 | sum | +----+----+-----+-----+ @@ -1730,7 +1856,7 @@ async fn with_column() -> Result<()> { | a | 3 | 14 | 17 | | a | 3 | 17 | 20 | +----+----+-----+-----+ - "### + " ); // check that col with the same name overwritten @@ -1742,7 +1868,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results_overwrite), - @r###" + @r" +-----+----+-----+-----+ | c1 | c2 | c3 | sum | +-----+----+-----+-----+ @@ -1753,7 +1879,7 @@ async fn with_column() -> Result<()> { | 17 | 3 | 14 | 17 | | 20 | 3 | 17 | 20 | +-----+----+-----+-----+ - "### + " ); // check that col with the same name overwritten using same name as reference @@ -1765,7 +1891,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results_overwrite_self), - @r###" + @r" +----+----+-----+-----+ | c1 | c2 | c3 | sum | +----+----+-----+-----+ @@ -1776,7 +1902,7 @@ async fn with_column() -> Result<()> { | a | 4 | 14 | 17 | | a | 4 | 17 | 20 | +----+----+-----+-----+ - "### + " ); Ok(()) @@ -1804,14 +1930,14 @@ async fn test_window_function_with_column() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+-----+---+ | c1 | c2 | c3 | s | r | +----+----+-----+-----+---+ | c | 2 | 1 | 3 | 1 | | d | 5 | -40 | -35 | 2 | +----+----+-----+-----+---+ - "### + " ); Ok(()) @@ -1846,13 +1972,13 @@ async fn with_column_join_same_columns() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+ | c1 | c1 | +----+----+ | a | a | +----+----+ - "### + " ); let df_with_column = df.clone().with_column("new_column", lit(true))?; @@ -1875,7 +2001,7 @@ async fn with_column_join_same_columns() -> Result<()> { assert_snapshot!( df_with_column.clone().into_optimized_plan().unwrap(), - @r###" + @r" Projection: t1.c1, t2.c1, Boolean(true) AS new_column Sort: t1.c1 ASC NULLS FIRST, fetch=1 Inner Join: t1.c1 = t2.c1 @@ -1883,20 +2009,20 @@ async fn with_column_join_same_columns() -> Result<()> { TableScan: aggregate_test_100 projection=[c1] SubqueryAlias: t2 TableScan: aggregate_test_100 projection=[c1] - "### + " ); let df_results = df_with_column.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+------------+ | c1 | c1 | new_column | +----+----+------------+ | a | a | true | +----+----+------------+ - "### + " ); Ok(()) @@ -1946,13 +2072,13 @@ async fn with_column_renamed() -> Result<()> { assert_snapshot!( batches_to_sort_string(batches), - @r###" + @r" +-----+-----+-----+-------+ | one | two | c3 | total | +-----+-----+-----+-------+ | a | 3 | -72 | -69 | +-----+-----+-----+-------+ - "### + " ); Ok(()) @@ -2017,13 +2143,13 @@ async fn with_column_renamed_join() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+----+----+-----+ | c1 | c2 | c3 | c1 | c2 | c3 | +----+----+-----+----+----+-----+ | a | 1 | 
-85 | a | 1 | -85 | +----+----+-----+----+----+-----+ - "### + " ); let df_renamed = df.clone().with_column_renamed("t1.c1", "AAA")?; @@ -2046,7 +2172,7 @@ async fn with_column_renamed_join() -> Result<()> { assert_snapshot!( df_renamed.clone().into_optimized_plan().unwrap(), - @r###" + @r" Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3 Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1 Inner Join: t1.c1 = t2.c1 @@ -2054,20 +2180,20 @@ async fn with_column_renamed_join() -> Result<()> { TableScan: aggregate_test_100 projection=[c1, c2, c3] SubqueryAlias: t2 TableScan: aggregate_test_100 projection=[c1, c2, c3] - "### + " ); let df_results = df_renamed.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +-----+----+-----+----+----+-----+ | AAA | c2 | c3 | c1 | c2 | c3 | +-----+----+-----+----+----+-----+ | a | 1 | -85 | a | 1 | -85 | +-----+----+-----+----+----+-----+ - "### + " ); Ok(()) @@ -2102,13 +2228,13 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { assert_snapshot!( batches_to_sort_string(res), - @r###" + @r" +---------+ | CoLuMn1 | +---------+ | a | +---------+ - "### + " ); let df_renamed = df_renamed @@ -2118,13 +2244,13 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_renamed), - @r###" + @r" +----+ | c1 | +----+ | a | +----+ - "### + " ); Ok(()) @@ -2162,19 +2288,19 @@ async fn describe_lookup_via_quoted_identifier() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&describe_result.clone().collect().await?), - @r###" - +------------+--------------+ - | describe | CoLu.Mn["1"] | - +------------+--------------+ - | count | 1 | - | max | a | - | mean | null | - | median | null | - | min | a | - | null_count | 0 | - | std | null | - +------------+--------------+ - "### + @r#" + +------------+--------------+ + | describe | CoLu.Mn["1"] | + +------------+--------------+ + | count | 1 | + | max | a | + | mean | null | + | median | null | + | min | a | + | null_count | 0 | + | std | null | + +------------+--------------+ + "# ); Ok(()) @@ -2192,13 +2318,13 @@ async fn cast_expr_test() -> Result<()> { df.clone().show().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+ | c2 | c3 | sum | +----+----+-----+ | 2 | 1 | 3 | +----+----+-----+ - "### + " ); Ok(()) @@ -2214,12 +2340,14 @@ async fn row_writer_resize_test() -> Result<()> { let data = RecordBatch::try_new( schema, - vec![ - Arc::new(StringArray::from(vec![ - Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), - Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), - ])) - ], + vec![Arc::new(StringArray::from(vec![ + Some( + "2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ), + Some( + "3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800", + ), + ]))], )?; let ctx = SessionContext::new(); @@ -2258,14 +2386,14 @@ async fn with_column_name() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+-------+ | f.c1 | f.c2 | 
+------+-------+ | 1 | hello | | 10 | hello | +------+-------+ - "### + " ); Ok(()) @@ -2301,13 +2429,13 @@ async fn cache_test() -> Result<()> { let cached_df_results = cached_df.collect().await?; assert_snapshot!( batches_to_sort_string(&cached_df_results), - @r###" + @r" +----+----+-----+ | c2 | c3 | sum | +----+----+-----+ | 2 | 1 | 3 | +----+----+-----+ - "### + " ); assert_eq!(&df_results, &cached_df_results); @@ -2315,6 +2443,28 @@ async fn cache_test() -> Result<()> { Ok(()) } +#[tokio::test] +async fn cache_producer_test() -> Result<()> { + let df = test_table_with_cache_factory() + .await? + .select_columns(&["c2", "c3"])? + .limit(0, Some(1))? + .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + + let cached_df = df.clone().cache().await?; + + assert_snapshot!( + cached_df.clone().into_optimized_plan().unwrap(), + @r" + CacheNode + Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum + Limit: skip=0, fetch=1 + TableScan: aggregate_test_100 projection=[c2, c3], fetch=1 + " + ); + Ok(()) +} + #[tokio::test] async fn partition_aware_union() -> Result<()> { let left = test_table().await?.select_columns(&["c1", "c2"])?; @@ -2584,13 +2734,13 @@ async fn filtered_aggr_with_param_values() -> Result<()> { let df_results = df?.collect().await?; assert_snapshot!( batches_to_string(&df_results), - @r###" + @r" +------------------------------------------------+ | count(table1.c2) FILTER (WHERE table1.c3 > $1) | +------------------------------------------------+ | 54 | +------------------------------------------------+ - "### + " ); Ok(()) @@ -2638,7 +2788,7 @@ async fn write_parquet_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2648,7 +2798,7 @@ async fn write_parquet_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) @@ -2696,7 +2846,7 @@ async fn write_csv_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2706,7 +2856,7 @@ async fn write_csv_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) } @@ -2744,7 +2894,7 @@ async fn write_json_with_order() -> Result<()> { ctx.register_json( "data", test_path.to_str().unwrap(), - NdJsonReadOptions::default().schema(&schema), + JsonReadOptions::default().schema(&schema), ) .await?; @@ -2753,7 +2903,7 @@ async fn write_json_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2763,7 +2913,7 @@ async fn write_json_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) } @@ -2812,7 +2962,7 @@ async fn write_table_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------+ | tablecol1 | +-----------+ @@ -2822,7 +2972,7 @@ async fn write_table_with_order() -> Result<()> { | x | | z | +-----------+ - "### + " ); Ok(()) } @@ -2849,50 +2999,44 @@ async fn test_count_wildcard_on_sort() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" - +---------------+------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.b, 
count(*) | - | | Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST | - | | Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1)) | - | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] | - | | TableScan: t1 projection=[b] | - | physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] | - | | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] | - | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | - | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] | - | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------+ - "### + @r" + +---------------+------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+------------------------------------------------------------------------------------+ + | logical_plan | Sort: count(*) ASC NULLS LAST | + | | Projection: t1.b, count(Int64(1)) AS count(*) | + | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] | + | | TableScan: t1 projection=[b] | + | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | + | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*)] | + | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+------------------------------------------------------------------------------------+ + " ); assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+--------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+--------------------------------------------------------------------------------+ - | logical_plan | Sort: count(*) ASC NULLS LAST | - | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t1 projection=[b] | - | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | - | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | - | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+--------------------------------------------------------------------------------+ - "### + @r" + +---------------+----------------------------------------------------------------------------+ + | plan_type | plan | + 
+---------------+----------------------------------------------------------------------------+ + | logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST | + | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t1 projection=[b] | + | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | + | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------+ + " ); Ok(()) } @@ -2910,23 +3054,22 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r" - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: count(Int64(1)) AS count(*) | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | - | | ProjectionExec: expr=[4 as count(*)] | - | | PlaceholderRowExec | - | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | Projection: count(Int64(1)) AS count(*) | + | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | + | | TableScan: t2 projection=[] | + | physical_plan | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | + | | ProjectionExec: expr=[4 as count(*)] | + | | PlaceholderRowExec | + | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------+ " ); @@ -2956,22 +3099,21 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r" - 
+---------------+------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | - | | ProjectionExec: expr=[4 as count(*)] | - | | PlaceholderRowExec | - | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t2 projection=[] | + | physical_plan | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | + | | ProjectionExec: expr=[4 as count(*)] | + | | PlaceholderRowExec | + | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------+ " ); @@ -3077,15 +3219,17 @@ async fn test_count_wildcard_on_window() -> Result<()> { let df_results = ctx .table("t1") .await? - .select(vec![count_all_window() - .order_by(vec![Sort::new(col("a"), false, true)]) - .window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build() - .unwrap()])? + .select(vec![ + count_all_window() + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap(), + ])? .explain(false, false)? 
.collect() .await?; @@ -3113,30 +3257,29 @@ async fn test_count_wildcard_on_window() -> Result<()> { #[tokio::test] // Test with `repartition_sorts` disabled, causing a full resort of the data -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(false).await?, - @r#" + @r" AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted - SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[] - UnionExec - DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + UnionExec + DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet - "#); + "); Ok(()) } -#[ignore] // See https://github.com/apache/datafusion/issues/18380 #[tokio::test] // Test with `repartition_sorts` enabled to preserve pre-sorted partitions and avoid resorting -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(true).await?, - @r#" + @r" AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted SortPreservingMergeExec: [id@0 ASC NULLS LAST] AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted @@ -3144,53 +3287,7 @@ async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_reparti DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet - "#); - - // 💥 Doesn't pass, and generates this plan: - // - // AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted - // SortPreservingMergeExec: [id@0 ASC NULLS LAST] - // SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] - // AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[] - // UnionExec - // DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet - // DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet - // - // - // === Excerpt from the verbose explain === - // - // 
+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - // | plan_type | plan | - // +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - // | initial_physical_plan | AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | UnionExec | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet | - // ... - // | physical_plan after EnforceDistribution | OutputRequirementExec: order_by=[], dist_by=Unspecified | - // | | AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] | - // | | CoalescePartitionsExec | - // | | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | UnionExec | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet | - // | | | - // | physical_plan after CombinePartialFinalAggregate | SAME TEXT AS ABOVE - // | | | - // | physical_plan after EnforceSorting | OutputRequirementExec: order_by=[], dist_by=Unspecified | - // | | AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | SortPreservingMergeExec: [id@0 ASC NULLS LAST] | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] | - // | | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[] | - // | | UnionExec | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet | - // ... 
- // +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + "); Ok(()) } @@ -3275,7 +3372,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" + @r" +---------------+-----------------------------------------------------+ | plan_type | plan | +---------------+-----------------------------------------------------+ @@ -3286,7 +3383,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { | | PlaceholderRowExec | | | | +---------------+-----------------------------------------------------+ - "### + " ); // add `.select(vec![count_wildcard()])?` to make sure we can analyze all node instead of just top node. @@ -3301,7 +3398,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" + @r" +---------------+---------------------------------------------------------------+ | plan_type | plan | +---------------+---------------------------------------------------------------+ @@ -3311,7 +3408,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { | | PlaceholderRowExec | | | | +---------------+---------------------------------------------------------------+ - "### + " ); Ok(()) @@ -3331,32 +3428,30 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r" - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.a, t1.b | - | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | - | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | - | | Left Join: t1.a = __scalar_sq_1.a | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __scalar_sq_1 | - | | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true | - | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[a] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] | - | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] 
| - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ + +---------------+--------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+--------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: t1.a, t1.b | + | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | + | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | + | | Left Join: t1.a = __scalar_sq_1.a | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __scalar_sq_1 | + | | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true | + | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] | + | | TableScan: t2 projection=[a] | + | physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] | + | | CoalescePartitionsExec | + | | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] | + | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] | + | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+--------------------------------------------------------------------------------------------------------------------------+ " ); @@ -3388,32 +3483,30 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r" - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.a, t1.b | - | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | - | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | - | | Left Join: t1.a = __scalar_sq_1.a | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __scalar_sq_1 | - | | Projection: count(*), t2.a, Boolean(true) AS __always_true | - | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[a] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] | - | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], 
aggr=[count(*)] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ + +---------------+--------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+--------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: t1.a, t1.b | + | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | + | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | + | | Left Join: t1.a = __scalar_sq_1.a | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __scalar_sq_1 | + | | Projection: count(*), t2.a, Boolean(true) AS __always_true | + | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t2 projection=[a] | + | physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] | + | | CoalescePartitionsExec | + | | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] | + | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] | + | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+--------------------------------------------------------------------------------------------------------------------------+ " ); @@ -3498,7 +3591,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | a | +-----+ @@ -3507,7 +3600,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { | 10 | | 1 | +-----+ - "### + " ); Ok(()) @@ -3545,7 +3638,7 @@ async fn sort_on_distinct_columns() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | a | +-----+ @@ -3553,7 +3646,7 @@ async fn sort_on_distinct_columns() -> Result<()> { | 10 | | 1 | +-----+ - "### + " ); Ok(()) } @@ -3684,14 +3777,14 @@ async fn filter_with_alias_overwrite() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+ | a | +------+ | true | | true | +------+ - "### + " ); Ok(()) @@ -3720,7 +3813,7 @@ async fn select_with_alias_overwrite() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-------+ | a | +-------+ @@ -3729,7 +3822,7 @@ async fn select_with_alias_overwrite() -> Result<()> { | true | | false | +-------+ - "### + " ); Ok(()) @@ -3755,7 +3848,7 @@ async fn test_grouping_sets() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------+-----+---------------+ | a | b | count(test.a) | 
+-----------+-----+---------------+ @@ -3771,7 +3864,7 @@ async fn test_grouping_sets() -> Result<()> { | 123AbcDef | | 1 | | 123AbcDef | 100 | 1 | +-----------+-----+---------------+ - "### + " ); Ok(()) @@ -3798,7 +3891,7 @@ async fn test_grouping_sets_count() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+----+-----------------+ | c1 | c2 | count(Int32(1)) | +----+----+-----------------+ @@ -3813,7 +3906,7 @@ async fn test_grouping_sets_count() -> Result<()> { | b | | 19 | | a | | 21 | +----+----+-----------------+ - "### + " ); Ok(()) @@ -3847,7 +3940,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+----+--------+---------------------+ | c1 | c2 | sum_c3 | avg_c3 | +----+----+--------+---------------------+ @@ -3887,7 +3980,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { | a | 2 | -46 | -15.333333333333334 | | a | 1 | -88 | -17.6 | +----+----+--------+---------------------+ - "### + " ); Ok(()) @@ -3924,25 +4017,25 @@ async fn join_with_alias_filter() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: t1.a, t2.a, t1.b, t1.c, t2.b, t2.c [a:UInt32, a:UInt32, b:Utf8, c:Int32, b:Utf8, c:Int32] Inner Join: t1.a + UInt32(3) = t2.a + UInt32(1) [a:UInt32, b:Utf8, c:Int32, a:UInt32, b:Utf8, c:Int32] TableScan: t1 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+----+---+----+---+---+ | a | a | b | c | b | c | +----+----+---+----+---+---+ | 1 | 3 | a | 10 | a | 1 | | 11 | 13 | c | 30 | c | 3 | +----+----+---+----+---+---+ - "### + " ); Ok(()) @@ -3969,27 +4062,27 @@ async fn right_semi_with_alias_filter() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" RightSemi Join: t1.a = t2.a [a:UInt32, b:Utf8, c:Int32] Projection: t1.a [a:UInt32] Filter: t1.c > Int32(1) [a:UInt32, c:Int32] TableScan: t1 projection=[a, c] [a:UInt32, c:Int32] Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-----+---+---+ | a | b | c | +-----+---+---+ | 10 | b | 2 | | 100 | d | 4 | +-----+---+---+ - "### + " ); Ok(()) @@ -4016,26 +4109,26 @@ async fn right_anti_filter_push_down() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" RightAnti Join: t1.a = t2.a Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32] Projection: t1.a [a:UInt32] Filter: t1.c > Int32(1) [a:UInt32, c:Int32] TableScan: t1 projection=[a, c] [a:UInt32, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+---+---+ | a | b | c | +----+---+---+ | 13 | c | 3 | | 3 | a | 1 | +----+---+---+ - "### + " ); Ok(()) @@ -4048,37 +4141,37 @@ async fn unnest_columns() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+---------------------------------+--------------------------+ - | shape_id | points | tags | - +----------+---------------------------------+--------------------------+ - | 1 | [{x: 5, y: -8}, {x: -3, y: 
-4}] | [tag1] | - | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | - | 4 | | [tag1, tag2, tag3] | - +----------+---------------------------------+--------------------------+ - "###); + @r" + +----------+---------------------------------+--------------------------+ + | shape_id | points | tags | + +----------+---------------------------------+--------------------------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+---------------------------------+--------------------------+ + "); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; let results = df.unnest_columns(&["tags"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+---------------------------------+------+ - | shape_id | points | tags | - +----------+---------------------------------+------+ - | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | - | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | - | 4 | | tag1 | - | 4 | | tag2 | - | 4 | | tag3 | - +----------+---------------------------------+------+ - "###); + @r" + +----------+---------------------------------+------+ + | shape_id | points | tags | + +----------+---------------------------------+------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+---------------------------------+------+ + "); // Test aggregate results for tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -4090,19 +4183,19 @@ async fn unnest_columns() -> Result<()> { let results = df.unnest_columns(&["points"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+----------------+--------------------------+ - | shape_id | points | tags | - +----------+----------------+--------------------------+ - | 1 | {x: -3, y: -4} | [tag1] | - | 1 | {x: 5, y: -8} | [tag1] | - | 2 | {x: -2, y: -8} | [tag1] | - | 2 | {x: 6, y: 2} | [tag1] | - | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | - | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | - | 4 | | [tag1, tag2, tag3] | - +----------+----------------+--------------------------+ - "###); + @r" + +----------+----------------+--------------------------+ + | shape_id | points | tags | + +----------+----------------+--------------------------+ + | 1 | {x: -3, y: -4} | [tag1] | + | 1 | {x: 5, y: -8} | [tag1] | + | 2 | {x: -2, y: -8} | [tag1] | + | 2 | {x: 6, y: 2} | [tag1] | + | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | + | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+----------------+--------------------------+ + "); // Test aggregate results for points. 
let df = table_with_nested_types(NUM_ROWS).await?; @@ -4118,27 +4211,27 @@ async fn unnest_columns() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+----------------+------+ - | shape_id | points | tags | - +----------+----------------+------+ - | 1 | {x: -3, y: -4} | tag1 | - | 1 | {x: 5, y: -8} | tag1 | - | 2 | {x: -2, y: -8} | tag1 | - | 2 | {x: 6, y: 2} | tag1 | - | 3 | {x: -2, y: 5} | tag1 | - | 3 | {x: -2, y: 5} | tag2 | - | 3 | {x: -2, y: 5} | tag3 | - | 3 | {x: -2, y: 5} | tag4 | - | 3 | {x: -9, y: -7} | tag1 | - | 3 | {x: -9, y: -7} | tag2 | - | 3 | {x: -9, y: -7} | tag3 | - | 3 | {x: -9, y: -7} | tag4 | - | 4 | | tag1 | - | 4 | | tag2 | - | 4 | | tag3 | - +----------+----------------+------+ - "###); + @r" + +----------+----------------+------+ + | shape_id | points | tags | + +----------+----------------+------+ + | 1 | {x: -3, y: -4} | tag1 | + | 1 | {x: 5, y: -8} | tag1 | + | 2 | {x: -2, y: -8} | tag1 | + | 2 | {x: 6, y: 2} | tag1 | + | 3 | {x: -2, y: 5} | tag1 | + | 3 | {x: -2, y: 5} | tag2 | + | 3 | {x: -2, y: 5} | tag3 | + | 3 | {x: -2, y: 5} | tag4 | + | 3 | {x: -9, y: -7} | tag1 | + | 3 | {x: -9, y: -7} | tag2 | + | 3 | {x: -9, y: -7} | tag3 | + | 3 | {x: -9, y: -7} | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+----------------+------+ + "); // Test aggregate results for points and tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -4178,7 +4271,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { let results = df.collect().await.unwrap(); assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------------+---------+ | make_array_expr | column1 | +-----------------+---------+ @@ -4186,7 +4279,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { | y | y | | z | z | +-----------------+---------+ - "### + " ); // make_array(dict_encoded_string,literal string) @@ -4206,7 +4299,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { let results = df.collect().await.unwrap(); assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------------+---------+ | make_array_expr | column1 | +-----------------+---------+ @@ -4217,7 +4310,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { | z | z | | fixed_string | z | +-----------------+---------+ - "### + " ); Ok(()) } @@ -4228,7 +4321,7 @@ async fn unnest_column_nulls() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +--------+----+ | list | id | +--------+----+ @@ -4237,7 +4330,7 @@ async fn unnest_column_nulls() -> Result<()> { | [] | C | | [3] | D | +--------+----+ - "### + " ); // Unnest, preserving nulls (row with B is preserved) @@ -4250,7 +4343,7 @@ async fn unnest_column_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+----+ | list | id | +------+----+ @@ -4259,7 +4352,7 @@ async fn unnest_column_nulls() -> Result<()> { | | B | | 3 | D | +------+----+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(false); @@ -4269,7 +4362,7 @@ async fn unnest_column_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+----+ | list | id | +------+----+ @@ -4277,7 +4370,7 @@ async fn unnest_column_nulls() -> Result<()> { | 2 | A | | 3 | D | +------+----+ - "### + " ); Ok(()) @@ -4294,7 +4387,7 @@ async fn unnest_fixed_list() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( 
batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4305,7 +4398,7 @@ async fn unnest_fixed_list() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(true); @@ -4316,7 +4409,7 @@ async fn unnest_fixed_list() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4331,7 +4424,7 @@ async fn unnest_fixed_list() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4348,7 +4441,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4359,7 +4452,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(false); @@ -4370,7 +4463,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4383,7 +4476,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4419,7 +4512,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4430,7 +4523,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(true); @@ -4440,7 +4533,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4457,7 +4550,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4471,17 +4564,17 @@ async fn unnest_aggregate_columns() -> Result<()> { let results = df.select_columns(&["tags"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +--------------------------+ - | tags | - +--------------------------+ - | [tag1, tag2, tag3, tag4] | - | [tag1, tag2, tag3] | - | [tag1, tag2] | - | [tag1] | - | [tag1] | - +--------------------------+ - "### + @r" + +--------------------------+ + | tags | + +--------------------------+ + | [tag1, tag2, tag3, tag4] | + | [tag1, tag2, tag3] | + | [tag1, tag2] | + | [tag1] | + | [tag1] | + +--------------------------+ + " ); let df = table_with_nested_types(NUM_ROWS).await?; @@ -4492,13 +4585,13 @@ async fn unnest_aggregate_columns() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-------------+ | count(tags) | +-------------+ | 11 | +-------------+ - "### + " ); Ok(()) @@ -4571,7 +4664,7 @@ async fn unnest_array_agg() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4585,7 +4678,7 @@ 
async fn unnest_array_agg() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); // Doing an `array_agg` by `shape_id` produces: @@ -4599,7 +4692,7 @@ async fn unnest_array_agg() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------------+ | shape_id | tag_id | +----------+--------------+ @@ -4607,7 +4700,7 @@ async fn unnest_array_agg() -> Result<()> { | 2 | [21, 22, 23] | | 3 | [31, 32, 33] | +----------+--------------+ - "### + " ); // Unnesting again should produce the original batch. @@ -4623,7 +4716,7 @@ async fn unnest_array_agg() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4637,7 +4730,7 @@ async fn unnest_array_agg() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); Ok(()) @@ -4667,7 +4760,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4681,7 +4774,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); // Doing an `array_agg` by `shape_id` produces: @@ -4703,7 +4796,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { @r" Projection: shapes.shape_id [shape_id:UInt32] Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N] - Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { data_type: UInt32, nullable: true });N] + Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(UInt32);N] TableScan: shapes projection=[shape_id] [shape_id:UInt32] " ); @@ -4711,7 +4804,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+ | shape_id | +----------+ @@ -4725,7 +4818,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { | 3 | | 3 | +----------+ - "### + " ); Ok(()) @@ -4766,7 +4859,7 @@ async fn unnest_multiple_columns() -> Result<()> { // string: a, b, c, d assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+------------+------------+--------+ | list | large_list | fixed_list | string | +------+------------+------------+--------+ @@ -4780,7 +4873,7 @@ async fn unnest_multiple_columns() -> Result<()> { | | | 4 | c | | | | | d | +------+------------+------------+--------+ - "### + " ); // Test with `preserve_nulls = false`` @@ -4797,7 +4890,7 @@ async fn unnest_multiple_columns() -> Result<()> { // string: a, b, c, d assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+------------+------------+--------+ | list | large_list | fixed_list | string | +------+------------+------------+--------+ @@ -4810,7 +4903,7 @@ async fn unnest_multiple_columns() -> Result<()> { | | | 3 | c | | | | 4 | c | +------+------------+------------+--------+ - "### + " ); Ok(()) @@ -4839,7 +4932,7 @@ async fn unnest_non_nullable_list() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+ | c1 | +----+ @@ -4847,7 +4940,7 @@ async fn unnest_non_nullable_list() -> Result<()> { | 2 | | | +----+ - "### + " ); Ok(()) @@ -4892,7 +4985,7 @@ async fn test_read_batches() -> Result<()> { 
let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+--------+ | id | number | +----+--------+ @@ -4905,7 +4998,7 @@ async fn test_read_batches() -> Result<()> { | 5 | 3.33 | | 5 | 6.66 | +----+--------+ - "### + " ); Ok(()) } @@ -4926,10 +5019,10 @@ async fn test_read_batches_empty() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" ++ ++ - "### + " ); Ok(()) } @@ -4978,14 +5071,14 @@ async fn consecutive_projection_same_schema() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+----+----+ | id | t | t2 | +----+----+----+ | 0 | | | | 1 | 10 | 10 | +----+----+----+ - "### + " ); Ok(()) @@ -5299,13 +5392,13 @@ async fn test_array_agg() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-------------------------------------+ | array_agg(test.a) | +-------------------------------------+ | [abcDEF, abc123, CBAdef, 123AbcDef] | +-------------------------------------+ - "### + " ); Ok(()) @@ -5373,10 +5466,10 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { // N.B., the test is basically `SELECT 1 as a WHERE a = 3;` which returns no results. assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" ++ ++ - "### + " ); Ok(()) @@ -5425,20 +5518,20 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, @r" - Projection: Int32(3) AS $1 [$1:Null;N] + Projection: Int32(3) AS $1 [$1:Int32] EmptyRelation: rows=1 [] " ); assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" +----+ | $1 | +----+ | 3 | +----+ - "### + " ); Ok(()) @@ -5505,42 +5598,45 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" +-----+ | a | +-----+ | foo | +-----+ - "### + " ); Ok(()) } +#[rstest] +#[case(DataType::Utf8)] +#[case(DataType::LargeUtf8)] +#[case(DataType::Utf8View)] #[tokio::test] -async fn write_partitioned_parquet_results() -> Result<()> { - // create partitioned input file and context - let tmp_dir = TempDir::new()?; - - let ctx = SessionContext::new(); - +async fn write_partitioned_parquet_results(#[case] string_type: DataType) -> Result<()> { // Create an in memory table with schema C1 and C2, both strings let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Utf8, false), + Field::new("c1", string_type.clone(), false), + Field::new("c2", string_type.clone(), false), ])); - let record_batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from(vec!["abc", "def"])), - Arc::new(StringArray::from(vec!["123", "456"])), - ], - )?; + let columns = [ + Arc::new(StringArray::from(vec!["abc", "def"])) as ArrayRef, + Arc::new(StringArray::from(vec!["123", "456"])) as ArrayRef, + ] + .map(|col| arrow::compute::cast(&col, &string_type).unwrap()) + .to_vec(); + + let record_batch = RecordBatch::try_new(schema.clone(), columns)?; let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![record_batch]])?); // Register the table in the context + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); ctx.register_table("test", mem_table)?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); @@ -5567,16 +5663,17 @@ async fn 
write_partitioned_parquet_results() -> Result<()> { // Check that the c2 column is gone and that c1 is abc. let results = filter_df.collect().await?; + insta::allow_duplicates! { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | c1 | +-----+ | abc | +-----+ - "### - ); + " + )}; // Read the entire set of parquet files let df = ctx @@ -5589,17 +5686,19 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the df has the entire set of data let results = df.collect().await?; - assert_snapshot!( - batches_to_sort_string(&results), - @r###" + insta::allow_duplicates! { + assert_snapshot!( + batches_to_sort_string(&results), + @r" +-----+-----+ | c1 | c2 | +-----+-----+ | abc | 123 | | def | 456 | +-----+-----+ - "### - ); + " + ) + }; Ok(()) } @@ -5755,7 +5854,7 @@ async fn sparse_union_is_null() { // view_all assert_snapshot!( batches_to_sort_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5766,14 +5865,14 @@ async fn sparse_union_is_null() { | {C=a} | | {C=} | +----------+ - "### + " ); // filter where is null let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5781,14 +5880,14 @@ async fn sparse_union_is_null() { | {B=} | | {C=} | +----------+ - "### + " ); // filter where is not null let result_df = df.filter(col("my_union").is_not_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5796,7 +5895,7 @@ async fn sparse_union_is_null() { | {B=3.2} | | {C=a} | +----------+ - "### + " ); } @@ -5838,7 +5937,7 @@ async fn dense_union_is_null() { // view_all assert_snapshot!( batches_to_sort_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5849,14 +5948,14 @@ async fn dense_union_is_null() { | {C=a} | | {C=} | +----------+ - "### + " ); // filter where is null let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5864,14 +5963,14 @@ async fn dense_union_is_null() { | {B=} | | {C=} | +----------+ - "### + " ); // filter where is not null let result_df = df.filter(col("my_union").is_not_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5879,7 +5978,7 @@ async fn dense_union_is_null() { | {B=3.2} | | {C=a} | +----------+ - "### + " ); } @@ -5911,7 +6010,7 @@ async fn boolean_dictionary_as_filter() { // view_all assert_snapshot!( batches_to_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +---------+ | my_dict | +---------+ @@ -5923,14 +6022,14 @@ async fn boolean_dictionary_as_filter() { | true | | false | +---------+ - "### + " ); let result_df = df.clone().filter(col("my_dict")).unwrap(); assert_snapshot!( batches_to_string(&result_df.collect().await.unwrap()), - @r###" + @r" +---------+ | my_dict | +---------+ @@ -5938,7 +6037,7 @@ async fn boolean_dictionary_as_filter() { | true | | true | +---------+ - "### + " ); // test nested dictionary @@ -5969,26 +6068,26 @@ async fn boolean_dictionary_as_filter() { // view_all assert_snapshot!( batches_to_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------------+ 
| my_nested_dict | +----------------+ | true | | false | +----------------+ - "### + " ); let result_df = df.clone().filter(col("my_nested_dict")).unwrap(); assert_snapshot!( batches_to_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------------+ | my_nested_dict | +----------------+ | true | +----------------+ - "### + " ); } @@ -6066,11 +6165,11 @@ async fn test_alias() -> Result<()> { .into_unoptimized_plan() .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: table_alias [a:Utf8, b:Int32, one:Int32] Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32] TableScan: test [a:Utf8, b:Int32] - "###); + "); // Select over the aliased DataFrame let df = df.select(vec![ @@ -6079,7 +6178,7 @@ async fn test_alias() -> Result<()> { ])?; assert_snapshot!( batches_to_sort_string(&df.collect().await.unwrap()), - @r###" + @r" +-----------+---------------------------------+ | a | table_alias.b + table_alias.one | +-----------+---------------------------------+ @@ -6088,7 +6187,7 @@ async fn test_alias() -> Result<()> { | abc123 | 11 | | abcDEF | 2 | +-----------+---------------------------------+ - "### + " ); Ok(()) } @@ -6118,7 +6217,7 @@ async fn test_alias_self_join() -> Result<()> { let joined = left.join(right, JoinType::Full, &["a"], &["a"], None)?; assert_snapshot!( batches_to_sort_string(&joined.collect().await.unwrap()), - @r###" + @r" +-----------+-----+-----------+-----+ | a | b | a | b | +-----------+-----+-----------+-----+ @@ -6127,7 +6226,7 @@ async fn test_alias_self_join() -> Result<()> { | abc123 | 10 | abc123 | 10 | | abcDEF | 1 | abcDEF | 1 | +-----------+-----+-----------+-----+ - "### + " ); Ok(()) } @@ -6140,14 +6239,14 @@ async fn test_alias_empty() -> Result<()> { .into_unoptimized_plan() .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: [a:Utf8, b:Int32] TableScan: test [a:Utf8, b:Int32] - "###); + "); assert_snapshot!( batches_to_sort_string(&df.select(vec![col("a"), col("b")])?.collect().await.unwrap()), - @r###" + @r" +-----------+-----+ | a | b | +-----------+-----+ @@ -6156,7 +6255,7 @@ async fn test_alias_empty() -> Result<()> { | abc123 | 10 | | abcDEF | 1 | +-----------+-----+ - "### + " ); Ok(()) @@ -6175,12 +6274,12 @@ async fn test_alias_nested() -> Result<()> { .into_optimized_plan()? 
.display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: alias2 [a:Utf8, b:Int32, one:Int32] SubqueryAlias: alias1 [a:Utf8, b:Int32, one:Int32] Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32] TableScan: test projection=[a, b] [a:Utf8, b:Int32] - "###); + "); // Select over the aliased DataFrame let select1 = df @@ -6189,7 +6288,7 @@ async fn test_alias_nested() -> Result<()> { assert_snapshot!( batches_to_sort_string(&select1.collect().await.unwrap()), - @r###" + @r" +-----------+-----------------------+ | a | alias2.b + alias2.one | +-----------+-----------------------+ @@ -6198,7 +6297,7 @@ async fn test_alias_nested() -> Result<()> { | abc123 | 11 | | abcDEF | 2 | +-----------+-----------------------+ - "### + " ); // Only the outermost alias is visible @@ -6217,7 +6316,7 @@ async fn register_non_json_file() { .register_json( "data", "tests/data/test_binary.parquet", - NdJsonReadOptions::default(), + JsonReadOptions::default(), ) .await; assert_contains!( @@ -6318,7 +6417,10 @@ async fn test_insert_into_checking() -> Result<()> { .await .unwrap_err(); - assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8"); + assert_contains!( + e.to_string(), + "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8" + ); Ok(()) } @@ -6365,7 +6467,7 @@ async fn test_fill_null() -> Result<()> { let results = df_filled.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -6373,7 +6475,7 @@ async fn test_fill_null() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); Ok(()) @@ -6393,7 +6495,7 @@ async fn test_fill_null_all_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -6401,7 +6503,7 @@ async fn test_fill_null_all_columns() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); // Fill column "a" null values with a value that cannot be cast to Int32. @@ -6410,7 +6512,7 @@ async fn test_fill_null_all_columns() -> Result<()> { let results = df_filled.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -6418,7 +6520,7 @@ async fn test_fill_null_all_columns() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); Ok(()) } @@ -6427,7 +6529,7 @@ async fn test_fill_null_all_columns() -> Result<()> { async fn test_insert_into_casting_support() -> Result<()> { // Testing case1: // Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8. - // And the cast is not supported from Utf8 to Float16. + // And the cast is not supported from Binary to Float16. 
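+    // (For context: the x'..' literals below produce Binary-typed values, and
+    // Arrow provides no cast from Binary to Float16, so this case still
+    // exercises the "cast not supported" error path asserted further down.)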
// Create a new schema with one field called "a" of type Float16, and setting nullable to false let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float16, false)])); @@ -6438,7 +6540,10 @@ async fn test_insert_into_casting_support() -> Result<()> { let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); session_ctx.register_table("t", initial_table.clone())?; - let mut write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap(); + let mut write_df = session_ctx + .sql("values (x'a123'), (x'b456')") + .await + .unwrap(); write_df = write_df .clone() @@ -6450,7 +6555,10 @@ async fn test_insert_into_casting_support() -> Result<()> { .await .unwrap_err(); - assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8."); + assert_contains!( + e.to_string(), + "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Binary." + ); // Testing case2: // Inserting query schema mismatch: Expected table field 'a' with type Utf8View, but got 'a' with type Utf8. @@ -6488,14 +6596,14 @@ async fn test_insert_into_casting_support() -> Result<()> { assert_snapshot!( batches_to_string(&res), - @r###" + @r" +------+ | a | +------+ | a123 | | b456 | +------+ - "### + " ); Ok(()) } @@ -6631,13 +6739,13 @@ async fn test_copy_to_preserves_order() -> Result<()> { // Expect that input to the DataSinkExec is sorted correctly assert_snapshot!( physical_plan_format, - @r###" + @r" UnionExec DataSinkExec: sink=CsvSink(file_groups=[]) SortExec: expr=[column1@0 DESC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[1] DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); Ok(()) } @@ -6743,3 +6851,50 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +/// Regression test for https://github.com/apache/datafusion/issues/21411 +/// grouping() should work when wrapped in an alias via the DataFrame API. +/// +/// This bug only manifests through the DataFrame API because `.alias()` wraps +/// the `grouping()` call in an `Expr::Alias` node at the aggregate expression +/// level. The SQL planner handles aliasing separately (via projection), so the +/// `ResolveGroupingFunction` analyzer rule never sees an `Expr::Alias` wrapper +/// around the aggregate function in SQL queries — making SQL-based tests +/// insufficient to cover this case. +#[tokio::test] +async fn test_grouping_with_alias() -> Result<()> { + use datafusion_functions_aggregate::expr_fn::grouping; + + let df = create_test_table("test") + .await? + .aggregate(vec![col("a")], vec![grouping(col("a")).alias("g")])? + .sort(vec![Sort::new(col("a"), true, false)])?; + + let results = df.collect().await?; + + let expected = [ + "+-----------+---+", + "| a | g |", + "+-----------+---+", + "| 123AbcDef | 0 |", + "| CBAdef | 0 |", + "| abc123 | 0 |", + "| abcDEF | 0 |", + "+-----------+---+", + ]; + assert_batches_eq!(expected, &results); + + // Also verify that nested aliases (e.g. .alias("x").alias("g")) work correctly + let df = create_test_table("test") + .await? + .aggregate( + vec![col("a")], + vec![grouping(col("a")).alias("x").alias("g")], + )? 
+ .sort(vec![Sort::new(col("a"), true, false)])?; + + let results = df.collect().await?; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/datasource/object_store_access.rs b/datafusion/core/tests/datasource/object_store_access.rs index f89ca9e049147..83b84f6f9284e 100644 --- a/datafusion/core/tests/datasource/object_store_access.rs +++ b/datafusion/core/tests/datasource/object_store_access.rs @@ -27,17 +27,21 @@ use arrow::array::{ArrayRef, Int32Array, RecordBatch}; use async_trait::async_trait; use bytes::Bytes; -use datafusion::prelude::{CsvReadOptions, ParquetReadOptions, SessionContext}; +use datafusion::prelude::{ + CsvReadOptions, JsonReadOptions, ParquetReadOptions, SessionContext, +}; use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; use datafusion_datasource::ListingTableUrl; use datafusion_datasource_csv::CsvFormat; +use datafusion_datasource_json::JsonFormat; use futures::stream::BoxStream; use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; use object_store::{ - GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + CopyOptions, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, PutPayload, + PutResult, }; use parking_lot::Mutex; use std::fmt; @@ -54,8 +58,8 @@ async fn create_single_csv_file() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=csv_table.csv - - GET path=csv_table.csv + - GET (opts) path=csv_table.csv head=true + - GET (opts) path=csv_table.csv " ); } @@ -76,7 +80,7 @@ async fn query_single_csv_file() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 2 - - HEAD path=csv_table.csv + - GET (opts) path=csv_table.csv head=true - GET (opts) path=csv_table.csv " ); @@ -91,15 +95,15 @@ async fn create_multi_file_csv_file() { RequestCountingObjectStore() Total Requests: 4 - LIST prefix=data - - GET path=data/file_0.csv - - GET path=data/file_1.csv - - GET path=data/file_2.csv + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv " ); } #[tokio::test] -async fn query_multi_csv_file() { +async fn multi_query_multi_file_csv_file() { let test = Test::new().with_multi_file_csv().await; assert_snapshot!( test.query("select * from csv_table").await, @@ -117,6 +121,56 @@ async fn query_multi_csv_file() { +---------+-------+-------+ ------- Object Store Request Summary ------- RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); + + // Force a cache eviction by removing the data limit for the cache + assert_snapshot!( + test.query("set datafusion.runtime.list_files_cache_limit=\"0K\"").await, + @r" + ------- Query Output (0 rows) ------- + ++ + ++ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 0 + " + ); + + // Then re-enable the cache + assert_snapshot!( + test.query("set datafusion.runtime.list_files_cache_limit=\"1M\"").await, + @r" + ------- Query Output (0 rows) ------- + ++ + ++ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 0 + " + ); + + // this query should list the table since the cache entries were evicted + assert_snapshot!( + test.query("select * from 
csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() Total Requests: 4 - LIST prefix=data - GET (opts) path=data/file_0.csv @@ -124,6 +178,114 @@ async fn query_multi_csv_file() { - GET (opts) path=data/file_2.csv " ); + + // this query should not list the table since the entries were added in the previous query + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + +#[tokio::test] +async fn query_multi_csv_file() { + let test = Test::new().with_multi_file_csv().await; + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + +/// Test that a CSV file split into byte ranges via repartitioning exercises +/// range-based object store access. +/// +/// With a single file and `target_partitions=3`, the repartitioner produces +/// exactly 3 ranges. For each range, `calculate_range` calls +/// `find_first_newline` via a GET for every non-file boundary it touches +/// (the start boundary if `start > 0`, the end boundary if `end < file_size`), +/// plus one GET for the actual data — so 2 GETs for the first range (end scan +/// + data), 3 for the middle range (start scan + end scan + data), and 2 for +/// the last range (start scan + data) = 7 data GETs total. Additionally, +/// adjacent ranges share a boundary position, so each shared boundary is scanned +/// twice — once as the left range's end and again as the right range's start — +/// producing the duplicate GETs visible in the snapshot. Add the 1 HEAD for +/// file-size metadata = **8 total**. +/// +/// This differs from the JSON reader which uses [`AlignedBoundaryStream`] and +/// needs only 1 GET per range. +/// +/// This test documents the current request pattern to catch regressions. +#[tokio::test] +async fn query_csv_file_with_byte_range_partitions() { + let test = Test::new().with_single_file_csv_for_range_test().await; + // Lower the repartition_file_min_size threshold so the small test file gets + // split, and set target_partitions=3 to produce exactly 3 ranges. 
+ test.query("SET datafusion.optimizer.repartition_file_min_size = 0") + .await; + test.query("SET datafusion.execution.target_partitions = 3") + .await; + assert_snapshot!( + test.query("select * from csv_range_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.00001 | 1e-12 | false | + | 0.00002 | 2e-12 | false | + | 0.00003 | 3e-12 | false | + | 0.00004 | 4e-12 | false | + | 0.00005 | 5e-12 | false | + | 0.00006 | 6e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 8 + - GET (opts) path=csv_range_table.csv head=true + - GET (opts) path=csv_range_table.csv range=42-129 + - GET (opts) path=csv_range_table.csv range=0-49 + - GET (opts) path=csv_range_table.csv range=42-129 + - GET (opts) path=csv_range_table.csv range=85-129 + - GET (opts) path=csv_range_table.csv range=49-89 + - GET (opts) path=csv_range_table.csv range=85-129 + - GET (opts) path=csv_range_table.csv range=89-129 + " + ); } #[tokio::test] @@ -145,17 +307,7 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 13 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 3 - GET (opts) path=data/a=1/b=10/c=100/file_1.csv - GET (opts) path=data/a=2/b=20/c=200/file_2.csv - GET (opts) path=data/a=3/b=30/c=300/file_3.csv @@ -174,10 +326,7 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 4 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 + Total Requests: 1 - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -194,17 +343,7 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 11 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 1 - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -221,17 +360,7 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 11 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST 
(with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 1 - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -248,9 +377,7 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 3 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 + Total Requests: 1 - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -267,22 +394,354 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 11 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 1 - GET (opts) path=data/a=1/b=10/c=100/file_1.csv " ); } +// ===================================================================== +// JSON (NDJSON) tests — mirrors the CSV tests above +// ===================================================================== + +#[tokio::test] +async fn create_single_json_file() { + let test = Test::new().with_single_file_json().await; + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=json_table.json head=true + - GET (opts) path=json_table.json + " + ); +} + +#[tokio::test] +async fn query_single_json_file() { + let test = Test::new().with_single_file_json().await; + assert_snapshot!( + test.query("select * from json_table").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.00001 | 5e-12 | true | + | 0.00002 | 4e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=json_table.json head=true + - GET (opts) path=json_table.json + " + ); +} + +#[tokio::test] +async fn create_multi_file_json() { + let test = Test::new().with_multi_file_json().await; + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET (opts) path=data/file_0.json + - GET (opts) path=data/file_1.json + - GET (opts) path=data/file_2.json + " + ); +} + +#[tokio::test] +async fn multi_query_multi_file_json() { + let test = Test::new().with_multi_file_json().await; + assert_snapshot!( + test.query("select * from json_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.json + - GET (opts) path=data/file_1.json + - GET (opts) path=data/file_2.json + " + ); + + 
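The SET statements that follow drive the list-files cache through an evict-and-refill cycle. As a toy model of the observable behavior (illustrative only, not DataFusion's cache implementation; all names here are made up): shrinking the limit to zero drops the cached listings, and raising the limit back does not restore them, so the next query must LIST again.

use std::collections::HashMap;

// Hypothetical model of a size-limited listing cache.
struct ListFilesCache {
    limit_bytes: u64,
    entries: HashMap<String, Vec<String>>, // table prefix -> cached file listing
}

impl ListFilesCache {
    fn set_limit(&mut self, limit_bytes: u64) {
        self.limit_bytes = limit_bytes;
        // A zero limit evicts everything; raising the limit later does not
        // resurrect evicted entries.
        if limit_bytes == 0 {
            self.entries.clear();
        }
    }

    fn lookup(&self, prefix: &str) -> Option<&Vec<String>> {
        self.entries.get(prefix) // a miss forces a LIST against the store
    }
}

fn main() {
    let mut cache = ListFilesCache { limit_bytes: 1 << 20, entries: HashMap::new() };
    cache.entries.insert("data".into(), vec!["data/file_0.json".into()]);
    cache.set_limit(0); // "0K": evict
    cache.set_limit(1 << 20); // "1M": re-enable, still empty
    assert!(cache.lookup("data").is_none()); // the next query re-lists the prefix
}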
// Force a cache eviction by removing the data limit for the cache + assert_snapshot!( + test.query("set datafusion.runtime.list_files_cache_limit=\"0K\"").await, + @r" + ------- Query Output (0 rows) ------- + ++ + ++ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 0 + " + ); + + // Then re-enable the cache + assert_snapshot!( + test.query("set datafusion.runtime.list_files_cache_limit=\"1M\"").await, + @r" + ------- Query Output (0 rows) ------- + ++ + ++ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 0 + " + ); + + // this query should list the table since the cache entries were evicted + assert_snapshot!( + test.query("select * from json_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET (opts) path=data/file_0.json + - GET (opts) path=data/file_1.json + - GET (opts) path=data/file_2.json + " + ); + + // this query should not list the table since the entries were added in the previous query + assert_snapshot!( + test.query("select * from json_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.json + - GET (opts) path=data/file_1.json + - GET (opts) path=data/file_2.json + " + ); +} + +#[tokio::test] +async fn query_multi_json_file() { + let test = Test::new().with_multi_file_json().await; + assert_snapshot!( + test.query("select * from json_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.json + - GET (opts) path=data/file_1.json + - GET (opts) path=data/file_2.json + " + ); +} + +#[tokio::test] +async fn query_partitioned_json_file() { + let test = Test::new().with_partitioned_json().await; + assert_snapshot!( + test.query("select * from json_table_partitioned").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00001 | 1e-12 | true | 1 | 10 | 100 | + | 0.00003 | 5e-12 | false | 1 | 10 | 100 | + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + | 0.00003 | 3e-12 | true | 3 | 30 | 300 | + | 0.00003 | 5e-12 | false | 3 | 30 | 300 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) 
path=data/a=1/b=10/c=100/file_1.json + - GET (opts) path=data/a=2/b=20/c=200/file_2.json + - GET (opts) path=data/a=3/b=30/c=300/file_3.json + " + ); + + assert_snapshot!( + test.query("select * from json_table_partitioned WHERE a=2").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.json + " + ); + + assert_snapshot!( + test.query("select * from json_table_partitioned WHERE b=20").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.json + " + ); + + assert_snapshot!( + test.query("select * from json_table_partitioned WHERE c=200").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.json + " + ); + + assert_snapshot!( + test.query("select * from json_table_partitioned WHERE a=2 AND b=20").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.json + " + ); + + assert_snapshot!( + test.query("select * from json_table_partitioned WHERE a<2 AND b=10 AND c=100").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00001 | 1e-12 | true | 1 | 10 | 100 | + | 0.00003 | 5e-12 | false | 1 | 10 | 100 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=1/b=10/c=100/file_1.json + " + ); +} + +/// Test that a JSON file split into byte ranges via repartitioning produces +/// exactly one GET request per byte range — no extra requests for boundary seeking. +/// +/// With a single file and `target_partitions=3`, the repartitioner produces +/// exactly 3 ranges. Each range is served by a single [`AlignedBoundaryStream`] +/// which issues exactly one bounded `get_opts` call, so there are 3 data GETs +/// plus 1 HEAD (to determine file size) = **4 total**. +/// +/// This differs from the CSV reader, which needs multiple GETs per range. 
+/// +/// This test documents the current request pattern to catch regressions. +#[tokio::test] +async fn query_json_file_with_byte_range_partitions() { + let test = Test::new().with_single_file_json_for_range_test().await; + // Lower the repartition_file_min_size threshold so the small test file gets + // split, and set target_partitions=3 to produce exactly 3 ranges. + test.query("SET datafusion.optimizer.repartition_file_min_size = 0") + .await; + test.query("SET datafusion.execution.target_partitions = 3") + .await; + assert_snapshot!( + test.query("select * from json_range_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+------+ + | c1 | c2 | c3 | + +---------+-------+------+ + | 0.00001 | 1e-12 | true | + | 0.00002 | 2e-12 | true | + | 0.00003 | 3e-12 | true | + | 0.00004 | 4e-12 | true | + | 0.00005 | 5e-12 | true | + | 0.00006 | 6e-12 | true | + +---------+-------+------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 4 + - GET (opts) path=json_range_table.json head=true + - GET (opts) path=json_range_table.json range=0-216 + - GET (opts) path=json_range_table.json range=71-216 + - GET (opts) path=json_range_table.json range=143-216 + " + ); +} + #[tokio::test] async fn create_single_parquet_file_default() { // The default metadata size hint is 512KB @@ -295,8 +754,8 @@ async fn create_single_parquet_file_default() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet - - GET (range) range=0-2994 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=0-2994 " ); } @@ -314,8 +773,8 @@ async fn create_single_parquet_file_prefetch() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet - - GET (range) range=1994-2994 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=1994-2994 " ); } @@ -343,10 +802,10 @@ async fn create_single_parquet_file_too_small_prefetch() { @r" RequestCountingObjectStore() Total Requests: 4 - - HEAD path=parquet_table.parquet - - GET (range) range=2494-2994 path=parquet_table.parquet - - GET (range) range=2264-2986 path=parquet_table.parquet - - GET (range) range=2124-2264 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=2494-2994 + - GET (ranges) path=parquet_table.parquet ranges=2264-2986 + - GET (ranges) path=parquet_table.parquet ranges=2124-2264 " ); } @@ -375,9 +834,9 @@ async fn create_single_parquet_file_small_prefetch() { @r" RequestCountingObjectStore() Total Requests: 3 - - HEAD path=parquet_table.parquet - - GET (range) range=2254-2994 path=parquet_table.parquet - - GET (range) range=2124-2264 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=2254-2994 + - GET (ranges) path=parquet_table.parquet ranges=2124-2264 " ); } @@ -399,8 +858,8 @@ async fn create_single_parquet_file_no_prefetch() { @r" RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet - - GET (range) range=0-2994 path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=0-2994 " ); } @@ -420,7 +879,7 @@ async fn query_single_parquet_file() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 3 
- - HEAD path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true - GET (ranges) path=parquet_table.parquet ranges=4-534,534-1064 - GET (ranges) path=parquet_table.parquet ranges=1064-1594,1594-2124 " @@ -444,7 +903,7 @@ async fn query_single_parquet_file_with_single_predicate() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 2 - - HEAD path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 " ); @@ -468,7 +927,7 @@ async fn query_single_parquet_file_multi_row_groups_multiple_predicates() { ------- Object Store Request Summary ------- RequestCountingObjectStore() Total Requests: 3 - - HEAD path=parquet_table.parquet + - GET (opts) path=parquet_table.parquet head=true - GET (ranges) path=parquet_table.parquet ranges=4-421,421-534,534-951,951-1064 - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 " @@ -630,6 +1089,116 @@ impl Test { .await } + /// Register a single CSV file with six equal-length rows for byte-range + /// repartitioning tests. With a single file and `target_partitions=3`, the + /// repartitioner creates exactly 3 ranges. + async fn with_single_file_csv_for_range_test(self) -> Test { + let csv_data = "c1,c2,c3\n\ + 0.00001,1e-12,false\n\ + 0.00002,2e-12,false\n\ + 0.00003,3e-12,false\n\ + 0.00004,4e-12,false\n\ + 0.00005,5e-12,false\n\ + 0.00006,6e-12,false\n"; + self.with_bytes("/csv_range_table.csv", csv_data) + .await + .register_csv("csv_range_table", "/csv_range_table.csv") + .await + } + + /// Register a JSON (NDJSON) file at the given path + async fn register_json(self, table_name: &str, path: &str) -> Self { + let url = format!("mem://{path}"); + self.session_context + .register_json(table_name, url, JsonReadOptions::default()) + .await + .unwrap(); + self + } + + /// Register a partitioned JSON table at the given path + async fn register_partitioned_json(self, table_name: &str, path: &str) -> Self { + let file_format = Arc::new(JsonFormat::default()); + let options = ListingOptions::new(file_format); + + let url = format!("mem://{path}").parse().unwrap(); + let table_url = ListingTableUrl::try_new(url, None).unwrap(); + + let session_state = self.session_context.state(); + let mut config = ListingTableConfig::new(table_url).with_listing_options(options); + config = config + .infer_partitions_from_path(&session_state) + .await + .unwrap(); + config = config.infer_schema(&session_state).await.unwrap(); + + let table = Arc::new(ListingTable::try_new(config).unwrap()); + self.session_context + .register_table(table_name, table) + .unwrap(); + self + } + + /// Register a single NDJSON file with three columns and two rows named `json_table` + async fn with_single_file_json(self) -> Test { + let json_data = "{\"c1\":0.00001,\"c2\":5e-12,\"c3\":true}\n\ + {\"c1\":0.00002,\"c2\":4e-12,\"c3\":false}\n"; + self.with_bytes("/json_table.json", json_data) + .await + .register_json("json_table", "/json_table.json") + .await + } + + /// Register a single NDJSON file with six equal-length rows for byte-range + /// repartitioning tests. With a single file and `target_partitions=3`, the + /// repartitioner creates exactly 3 ranges. 
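Each row in this fixture serializes to the same byte length on purpose, so the 3-way byte split lands at stable offsets. A quick standalone check of that sizing (a sketch, not part of the test code):

fn main() {
    let rows: Vec<String> = (1..=6)
        .map(|i| format!(r#"{{"c1":0.0000{i},"c2":{i}e-12,"c3":true}}"#))
        .collect();
    // every row plus its trailing newline is 36 bytes
    assert!(rows.iter().all(|r| r.len() + 1 == 36));
    let total: usize = rows.iter().map(|r| r.len() + 1).sum();
    assert_eq!(total, 216); // matches the 0-216 byte range in the snapshot above
    assert_eq!(total / 3, 72); // nominal per-partition range size
}

The GET offsets 71 and 143 in the snapshot appear to be these nominal 72 and 144 boundaries pulled back to the preceding newline by the reader.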
+ async fn with_single_file_json_for_range_test(self) -> Test { + let json_data = r#"{"c1":0.00001,"c2":1e-12,"c3":true} +{"c1":0.00002,"c2":2e-12,"c3":true} +{"c1":0.00003,"c2":3e-12,"c3":true} +{"c1":0.00004,"c2":4e-12,"c3":true} +{"c1":0.00005,"c2":5e-12,"c3":true} +{"c1":0.00006,"c2":6e-12,"c3":true} +"#; + self.with_bytes("/json_range_table.json", json_data) + .await + .register_json("json_range_table", "/json_range_table.json") + .await + } + + /// Register three NDJSON files in a directory, called `json_table` + async fn with_multi_file_json(mut self) -> Test { + for i in 0..3 { + let json_data = format!( + "{{\"c1\":0.0000{i},\"c2\":{i}e-12,\"c3\":true}}\n\ + {{\"c1\":0.00003,\"c2\":5e-12,\"c3\":false}}\n" + ); + self = self + .with_bytes(&format!("/data/file_{i}.json"), json_data) + .await; + } + self.register_json("json_table", "/data/").await + } + + /// Register three NDJSON files in a partitioned directory structure, called + /// `json_table_partitioned` + async fn with_partitioned_json(mut self) -> Test { + for i in 1..4 { + let json_data = format!( + "{{\"d1\":0.0000{i},\"d2\":{i}e-12,\"d3\":true}}\n\ + {{\"d1\":0.00003,\"d2\":5e-12,\"d3\":false}}\n" + ); + self = self + .with_bytes( + &format!("/data/a={i}/b={}/c={}/file_{i}.json", i * 10, i * 100), + json_data, + ) + .await; + } + self.register_partitioned_json("json_table_partitioned", "/data/") + .await + } + /// Add a single parquet file that has two columns and two row groups named `parquet_table` /// /// Column "a": Int32 with values 0-100] in row group 1 @@ -645,7 +1214,7 @@ impl Test { let mut buffer = vec![]; let props = parquet::file::properties::WriterProperties::builder() - .set_max_row_group_size(100) + .set_max_row_group_row_count(Some(100)) .build(); let mut writer = parquet::arrow::ArrowWriter::try_new( &mut buffer, @@ -696,11 +1265,8 @@ impl Test { /// Details of individual requests made through the [`RequestCountingObjectStore`] #[derive(Clone, Debug)] enum RequestDetails { - Get { path: Path }, GetOpts { path: Path, get_options: GetOptions }, GetRanges { path: Path, ranges: Vec<Range<u64>> }, - GetRange { path: Path, range: Range<u64> }, - Head { path: Path }, List { prefix: Option<Path> }, ListWithDelimiter { prefix: Option<Path> }, ListWithOffset { prefix: Option<Path>, offset: Path }, @@ -718,9 +1284,6 @@ fn display_range(range: &Range<u64>) -> impl Display + '_ { impl Display for RequestDetails { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { - RequestDetails::Get { path } => { - write!(f, "GET path={path}") - } RequestDetails::GetOpts { path, get_options } => { write!(f, "GET (opts) path={path}")?; if let Some(range) = &get_options.range { @@ -758,13 +1321,6 @@ impl Display for RequestDetails { } Ok(()) } - RequestDetails::GetRange { path, range } => { - let range = display_range(range); - write!(f, "GET (range) range={range} path={path}") - } - RequestDetails::Head { path } => { - write!(f, "HEAD path={path}") - } RequestDetails::List { prefix } => { write!(f, "LIST")?; if let Some(prefix) = prefix { @@ -837,7 +1393,7 @@ impl ObjectStore for RequestCountingObjectStore { _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result<PutResult> { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn put_multipart_opts( @@ -845,15 +1401,7 @@ &self, _location: &Path, _opts: PutMultipartOptions, ) -> object_store::Result<Box<dyn MultipartUpload>> { - Err(object_store::Error::NotImplemented) - } - - async fn get(&self, location: &Path) -> object_store::Result<GetResult> { - let result =
self.inner.get(location).await?; - self.requests.lock().push(RequestDetails::Get { - path: location.to_owned(), - }); - Ok(result) + unimplemented!() } async fn get_opts( @@ -869,19 +1417,6 @@ Ok(result) } - async fn get_range( - &self, - location: &Path, - range: Range<u64>, - ) -> object_store::Result<Bytes> { - let result = self.inner.get_range(location, range.clone()).await?; - self.requests.lock().push(RequestDetails::GetRange { - path: location.to_owned(), - range: range.clone(), - }); - Ok(result) - } - async fn get_ranges( &self, location: &Path, @@ -895,18 +1430,6 @@ Ok(result) } - async fn head(&self, location: &Path) -> object_store::Result<ObjectMeta> { - let result = self.inner.head(location).await?; - self.requests.lock().push(RequestDetails::Head { - path: location.to_owned(), - }); - Ok(result) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) - } - fn list( &self, prefix: Option<&Path>, @@ -942,15 +1465,19 @@ self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result<Path>>, + ) -> BoxStream<'static, object_store::Result<Path>> { + unimplemented!() } - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + unimplemented!() } } diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs index b6f406e967509..e02364a0530cc 100644 --- a/datafusion/core/tests/execution/coop.rs +++ b/datafusion/core/tests/execution/coop.rs @@ -22,26 +22,25 @@ use datafusion::common::NullEquality; use datafusion::functions_aggregate::sum; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan; +use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion::physical_plan::execution_plan::Boundedness; -use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; -use datafusion_common::{exec_datafusion_err, DataFusionError, JoinType, ScalarValue}; +use datafusion_common::{DataFusionError, JoinType, ScalarValue, exec_datafusion_err}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr_common::operator::Operator; use datafusion_expr_common::operator::Operator::{Divide, Eq, Gt, Modulo}; use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::expressions::{ - binary, col, lit, BinaryExpr, Column, Literal, + BinaryExpr, Column, Literal, binary, col, lit, }; -use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; use
datafusion_physical_plan::coop::make_cooperative; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; @@ -64,13 +63,14 @@ use std::time::Duration; use tokio::runtime::{Handle, Runtime}; use tokio::select; -#[derive(Debug)] +#[derive(Debug, Clone)] struct RangeBatchGenerator { schema: SchemaRef, value_range: Range<i64>, boundedness: Boundedness, batch_size: usize, poll_count: usize, + original_range: Range<i64>, } impl std::fmt::Display for RangeBatchGenerator { @@ -110,6 +110,13 @@ impl LazyBatchGenerator for RangeBatchGenerator { RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; Ok(Some(batch)) } + + fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> { + let mut new = self.clone(); + new.poll_count = 0; + new.value_range = new.original_range.clone(); + Arc::new(RwLock::new(new)) + } } fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { @@ -136,16 +143,17 @@ fn make_lazy_exec_with_range( }; // Instantiate the generator with the batch and limit - let gen = RangeBatchGenerator { + let batch_gen = RangeBatchGenerator { schema: Arc::clone(&schema), boundedness, - value_range: range, + value_range: range.clone(), batch_size: 8192, poll_count: 0, + original_range: range, }; // Wrap the generator in a trait object behind Arc<RwLock<dyn LazyBatchGenerator>> - let generator: Arc<RwLock<dyn LazyBatchGenerator>> = Arc::new(RwLock::new(gen)); + let generator: Arc<RwLock<dyn LazyBatchGenerator>> = Arc::new(RwLock::new(batch_gen)); // Create a LazyMemoryExec with one partition using our generator let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); @@ -170,7 +178,7 @@ async fn agg_no_grouping_yields( let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); let aggr = Arc::new(AggregateExec::try_new( AggregateMode::Single, - PhysicalGroupBy::new(vec![], vec![], vec![]), + PhysicalGroupBy::new(vec![], vec![], vec![], false), vec![Arc::new( AggregateExprBuilder::new( sum::sum_udaf(), @@ -204,7 +212,7 @@ async fn agg_grouping_yields( let aggr = Arc::new(AggregateExec::try_new( AggregateMode::Single, - PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![]), + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![], false), vec![Arc::new( AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) .schema(inf.schema()) @@ -225,6 +233,7 @@ async fn agg_grouped_topk_yields( #[values(false, true)] pretend_infinite: bool, ) -> Result<(), Box<dyn Error>> { // build session + let session_ctx = SessionContext::new(); // set up a top-k aggregation @@ -240,6 +249,7 @@ async fn agg_grouped_topk_yields( vec![(group, "group".to_string())], vec![], vec![vec![false]], + false, ), vec![Arc::new( AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) @@ -251,7 +261,7 @@ async fn agg_grouped_topk_yields( inf.clone(), inf.schema(), )?
- .with_limit(Some(100)), + .with_limit_options(Some(LimitOptions::new(100))), ); query_yields(aggr, session_ctx.task_ctx()).await } @@ -415,10 +425,7 @@ async fn filter_reject_all_batches_yields( )); let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?); - // Use CoalesceBatchesExec to guarantee each Filter pull always yields an 8192-row batch - let coalesced = Arc::new(CoalesceBatchesExec::new(filtered, 8_192)); - - query_yields(coalesced, session_ctx.task_ctx()).await + query_yields(filtered, session_ctx.task_ctx()).await } #[rstest] @@ -545,6 +552,7 @@ async fn interleave_then_aggregate_yields( vec![], // no GROUP BY columns vec![], // no GROUP BY expressions vec![], // no GROUP BY physical expressions + false, ), vec![Arc::new(aggregate_expr)], vec![None], // no “distinct” flags @@ -573,17 +581,18 @@ async fn join_yields( let left_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; let right_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; - // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition - let coalesced_left = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); - let coalesced_right = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); - let part_left = Partitioning::Hash(left_keys, 1); let part_right = Partitioning::Hash(right_keys, 1); - let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); - let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + // Wrap each side in Repartition so they are both hashed into 1 partition + let hashed_left = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_left), + part_left, + )?); + let hashed_right = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_right), + part_right, + )?); // Build an Inner HashJoinExec → left.value = right.value let join = Arc::new(HashJoinExec::try_new( @@ -598,6 +607,7 @@ async fn join_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -621,17 +631,18 @@ async fn join_agg_yields( let left_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; let right_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; - // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition - let coalesced_left = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); - let coalesced_right = - Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); - let part_left = Partitioning::Hash(left_keys, 1); let part_right = Partitioning::Hash(right_keys, 1); - let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); - let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + // Wrap each side in Repartition so they are both hashed into 1 partition + let hashed_left = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_left), + part_left, + )?); + let hashed_right = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_right), + part_right, + )?); // Build an Inner HashJoinExec → left.value = right.value let join = Arc::new(HashJoinExec::try_new( @@ -646,6 +657,7 @@ async fn join_agg_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); // Project only one column (“value” from the left side) because we just want to sum that let proj_expr = vec![ProjectionExpr::new(
Arc::new(Column::new_with_schema("value", &input_schema)?) as _, - "value".to_string(), + "value", )]; let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?); @@ -676,7 +688,7 @@ async fn join_agg_yields( let aggr = Arc::new(AggregateExec::try_new( AggregateMode::Single, - PhysicalGroupBy::new(vec![], vec![], vec![]), + PhysicalGroupBy::new(vec![], vec![], vec![], false), vec![Arc::new(aggregate_expr)], vec![None], projection, @@ -711,6 +723,7 @@ async fn hash_join_yields( None, PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -742,9 +755,10 @@ async fn hash_join_without_repartition_and_no_agg( /* filter */ None, &JoinType::Inner, /* output64 */ None, - // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. + // Using CollectLeft is fine—just avoid RepartitionExec's partitioned channels. PartitionMode::CollectLeft, NullEquality::NullEqualsNull, + false, )?); query_yields(join, session_ctx.task_ctx()).await @@ -753,7 +767,7 @@ async fn hash_join_without_repartition_and_no_agg( #[derive(Debug)] enum Yielded { ReadyOrPending, - Err(#[allow(dead_code)] DataFusionError), + Err(#[expect(dead_code)] DataFusionError), Timeout, } @@ -780,9 +794,9 @@ async fn stream_yields( let yielded = select! { result = join_handle => { match result { - Ok(Pending) => Yielded::ReadyOrPending, - Ok(Ready(Ok(_))) => Yielded::ReadyOrPending, - Ok(Ready(Err(e))) => Yielded::Err(e), + Ok(Poll::Pending) => Yielded::ReadyOrPending, + Ok(Poll::Ready(Ok(_))) => Yielded::ReadyOrPending, + Ok(Poll::Ready(Err(e))) => Yielded::Err(e), Err(_) => Yielded::Err(exec_datafusion_err!("join error")), } }, diff --git a/datafusion/core/tests/execution/datasource_split.rs b/datafusion/core/tests/execution/datasource_split.rs index 0b90c6f326168..370249cd8044e 100644 --- a/datafusion/core/tests/execution/datasource_split.rs +++ b/datafusion/core/tests/execution/datasource_split.rs @@ -22,7 +22,7 @@ use arrow::{ }; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_execution::TaskContext; -use datafusion_physical_plan::{common::collect, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, common::collect}; use std::sync::Arc; /// Helper function to create a memory source with the given batch size and collect all batches diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index ef2e263f2c467..3eaa3fb2ed5e6 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -20,7 +20,7 @@ use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion::datasource::{provider_as_source, ViewTable}; +use datafusion::datasource::{ViewTable, provider_as_source}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs index 8770b2a201051..f33ef87aa3023 100644 --- a/datafusion/core/tests/execution/mod.rs +++ b/datafusion/core/tests/execution/mod.rs @@ -18,3 +18,4 @@ mod coop; mod datasource_split; mod logical_plan; +mod register_arrow; diff --git a/datafusion/core/tests/execution/register_arrow.rs b/datafusion/core/tests/execution/register_arrow.rs new file mode 100644 index 0000000000000..4ce16dc0906c1 --- /dev/null +++ 
b/datafusion/core/tests/execution/register_arrow.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for register_arrow API + +use datafusion::{execution::options::ArrowReadOptions, prelude::*}; +use datafusion_common::Result; + +#[tokio::test] +async fn test_register_arrow_auto_detects_format() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_arrow( + "file_format", + "../../datafusion/datasource-arrow/tests/data/example.arrow", + ArrowReadOptions::default(), + ) + .await?; + + ctx.register_arrow( + "stream_format", + "../../datafusion/datasource-arrow/tests/data/example_stream.arrow", + ArrowReadOptions::default(), + ) + .await?; + + let file_result = ctx.sql("SELECT * FROM file_format ORDER BY f0").await?; + let stream_result = ctx.sql("SELECT * FROM stream_format ORDER BY f0").await?; + + let file_batches = file_result.collect().await?; + let stream_batches = stream_result.collect().await?; + + assert_eq!(file_batches.len(), stream_batches.len()); + assert_eq!(file_batches[0].schema(), stream_batches[0].schema()); + + let file_rows: usize = file_batches.iter().map(|b| b.num_rows()).sum(); + let stream_rows: usize = stream_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(file_rows, stream_rows); + + Ok(()) +} + +#[tokio::test] +async fn test_register_arrow_join_file_and_stream() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_arrow( + "file_table", + "../../datafusion/datasource-arrow/tests/data/example.arrow", + ArrowReadOptions::default(), + ) + .await?; + + ctx.register_arrow( + "stream_table", + "../../datafusion/datasource-arrow/tests/data/example_stream.arrow", + ArrowReadOptions::default(), + ) + .await?; + + let result = ctx + .sql( + "SELECT a.f0, a.f1, b.f0, b.f1 + FROM file_table a + JOIN stream_table b ON a.f0 = b.f0 + WHERE a.f0 <= 2 + ORDER BY a.f0", + ) + .await?; + let batches = result.collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + + Ok(()) +} diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 84e644480a4fd..19ff3933193de 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -16,17 +16,16 @@ // under the License. 
use arrow::array::{ - builder::{ListBuilder, StringBuilder}, ArrayRef, Int64Array, RecordBatch, StringArray, StructArray, + builder::{ListBuilder, StringBuilder}, }; use arrow::datatypes::{DataType, Field}; use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use datafusion::prelude::*; use datafusion_common::{DFSchema, ScalarValue}; -use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::expr::NullTreatment; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; @@ -36,6 +35,7 @@ use datafusion_optimizer::simplify_expressions::ExprSimplifier; use std::sync::{Arc, LazyLock}; mod parse_sql_expr; +#[expect(clippy::needless_pass_by_value)] mod simplification; #[test] @@ -342,20 +342,26 @@ fn test_create_physical_expr_nvl2() { #[tokio::test] async fn test_create_physical_expr_coercion() { - // create_physical_expr does apply type coercion and unwrapping in cast + // create_physical_expr applies type coercion (and can unwrap/fold + // literal casts). Comparison coercion prefers numeric types, so + // string/int comparisons cast the string side to the numeric type. // - // expect the cast on the literals - // compare string function to int `id = 1` - create_expr_test(col("id").eq(lit(1i32)), "id@0 = CAST(1 AS Utf8)"); - create_expr_test(lit(1i32).eq(col("id")), "CAST(1 AS Utf8) = id@0"); - // compare int col to string literal `i = '202410'` - // Note this casts the column (not the field) - create_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); - create_expr_test(lit("202410").eq(col("i")), "202410 = CAST(i@1 AS Utf8)"); - // however, when simplified the casts on i should removed - // https://github.com/apache/datafusion/issues/14944 - create_simplified_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); - create_simplified_expr_test(lit("202410").eq(col("i")), "CAST(i@1 AS Utf8) = 202410"); + // string column vs int literal: id (Utf8) is cast to Int32 + create_expr_test(col("id").eq(lit(1i32)), "CAST(id@0 AS Int32) = 1"); + create_expr_test(lit(1i32).eq(col("id")), "1 = CAST(id@0 AS Int32)"); + // int column vs string literal: the string literal is cast to Int64 + create_expr_test(col("i").eq(lit("202410")), "i@1 = CAST(202410 AS Int64)"); + create_expr_test(lit("202410").eq(col("i")), "CAST(202410 AS Int64) = i@1"); + // The simplifier operates on the logical expression before type + // coercion adds the CAST, so the output is unchanged. + create_simplified_expr_test( + col("i").eq(lit("202410")), + "i@1 = CAST(202410 AS Int64)", + ); + create_simplified_expr_test( + lit("202410").eq(col("i")), + "i@1 = CAST(202410 AS Int64)", + ); } /// Evaluates the specified expr as an aggregate and compares the result to the @@ -384,6 +390,7 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided /// `RecordBatch` and compares the result to the expected result. 
+#[expect(clippy::needless_pass_by_value)] fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { let batch = &TEST_BATCH; let df_schema = DFSchema::try_from(batch.schema()).unwrap(); @@ -420,9 +427,9 @@ fn create_simplified_expr_test(expr: Expr, expected_expr: &str) { let df_schema = DFSchema::try_from(batch.schema()).unwrap(); // Simplify the expression first - let props = ExecutionProps::new(); - let simplify_context = - SimplifyContext::new(&props).with_schema(df_schema.clone().into()); + let simplify_context = SimplifyContext::builder() + .with_schema(Arc::new(df_schema)) + .build(); let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10); let simplified = simplifier.simplify(expr).unwrap(); create_expr_test(simplified, expected_expr); diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 92c18204324f7..b0d8b3a349ae2 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -19,9 +19,9 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::DFSchema; use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_expr::Expr; use datafusion_expr::col; use datafusion_expr::lit; -use datafusion_expr::Expr; use datafusion_sql::unparser::Unparser; /// A schema like: /// diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 46c36c6abdacc..245aba66849ce 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -23,16 +23,16 @@ use arrow::array::types::IntervalDayTime; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, TimeZone, Utc}; -use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; -use datafusion_common::cast::as_int32_array; +use datafusion::{error::Result, prelude::*}; use datafusion_common::ScalarValue; +use datafusion_common::cast::as_int32_array; use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, - ScalarUDF, Volatility, + Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Projection, + ScalarUDF, Volatility, table_scan, }; use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; @@ -40,50 +40,6 @@ use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpress use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; -/// In order to simplify expressions, DataFusion must have information -/// about the expressions. 
-/// -/// You can provide that information using DataFusion [DFSchema] -/// objects or from some other implementation -struct MyInfo { - /// The input schema - schema: DFSchemaRef, - - /// Execution specific details needed for constant evaluation such - /// as the current time for `now()` and [VariableProviders] - execution_props: ExecutionProps, -} - -impl SimplifyInfo for MyInfo { - fn is_boolean_type(&self, expr: &Expr) -> Result<bool> { - Ok(matches!( - expr.get_type(self.schema.as_ref())?, - DataType::Boolean - )) - } - - fn nullable(&self, expr: &Expr) -> Result<bool> { - expr.nullable(self.schema.as_ref()) - } - - fn execution_props(&self) -> &ExecutionProps { - &self.execution_props - } - - fn get_data_type(&self, expr: &Expr) -> Result<DataType> { - expr.get_type(self.schema.as_ref()) - } -} - -impl From<DFSchemaRef> for MyInfo { - fn from(schema: DFSchemaRef) -> Self { - Self { - schema, - execution_props: ExecutionProps::new(), - } - } -} - /// A schema like: /// /// a: Int32 (possibly with nulls) @@ -132,14 +88,11 @@ fn test_evaluate_with_start_time( expected_expr: Expr, date_time: &DateTime<Utc>, ) { - let execution_props = - ExecutionProps::new().with_query_execution_start_time(*date_time); - - let info: MyInfo = MyInfo { - schema: schema(), - execution_props, - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::builder() + .with_schema(schema()) + .with_query_execution_start_time(Some(*date_time)) + .build(); + let simplifier = ExprSimplifier::new(context); let simplified_expr = simplifier .simplify(input_expr.clone()) .expect("successfully evaluated"); @@ -201,7 +154,10 @@ fn to_timestamp_expr(arg: impl Into<String>) -> Expr { #[test] fn basic() { - let info: MyInfo = schema().into(); + let context = SimplifyContext::builder() + .with_schema(schema()) + .with_query_execution_start_time(Some(Utc::now())) + .build(); // The `Expr` is a core concept in DataFusion, and DataFusion can // help simplify it. @@ -210,21 +166,21 @@ fn basic() { // optimize form `a < 5` automatically let expr = col("a").lt(lit(2i32) + lit(3i32)); - let simplifier = ExprSimplifier::new(info); + let simplifier = ExprSimplifier::new(context); let simplified = simplifier.simplify(expr).unwrap(); assert_eq!(simplified, col("a").lt(lit(5i32))); } #[test] fn fold_and_simplify() { - let info: MyInfo = schema().into(); + let context = SimplifyContext::builder().with_schema(schema()).build(); // What will it do with the expression `concat('foo', 'bar') == 'foobar')`?
let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar")); // Since datafusion applies both simplification *and* rewriting // some expressions can be entirely simplified - let simplifier = ExprSimplifier::new(info); + let simplifier = ExprSimplifier::new(context); let simplified = simplifier.simplify(expr).unwrap(); assert_eq!(simplified, lit(true)) } @@ -243,10 +199,10 @@ fn to_timestamp_expr_folded() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r#" Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8("2020-09-08T12:00:00+00:00")) TableScan: test - "### + "# ); Ok(()) } @@ -273,10 +229,10 @@ fn now_less_than_timestamp() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Filter: Boolean(true) TableScan: test - "### + " ); Ok(()) } @@ -312,10 +268,10 @@ fn select_date_plus_interval() -> Result<()> { assert_snapshot!( actual, - @r###" + @r#" Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") TableScan: test - "### + "# ); Ok(()) } @@ -334,10 +290,10 @@ fn simplify_project_scalar_fn() -> Result<()> { let actual = formatter.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: test.f AS power(test.f,Float64(1)) TableScan: test - "### + " ); Ok(()) } @@ -523,6 +479,72 @@ fn multiple_now() -> Result<()> { Ok(()) } +/// Unwraps an alias expression to get the inner expression +fn unwrap_aliases(expr: &Expr) -> &Expr { + match expr { + Expr::Alias(alias) => unwrap_aliases(&alias.expr), + expr => expr, + } +} + +/// Test that `now()` is simplified to a literal when execution start time is set, +/// but remains as an expression when no execution start time is available. +#[test] +fn now_simplification_with_and_without_start_time() { + let plan = LogicalPlanBuilder::empty(false) + .project(vec![now()]) + .unwrap() + .build() + .unwrap(); + + // Case 1: With execution start time set, now() should be simplified to a literal + { + let time = DateTime::<Utc>::from_timestamp_nanos(123); + let ctx: OptimizerContext = + OptimizerContext::new().with_query_execution_start_time(time); + let optimizer = SimplifyExpressions {}; + let simplified = optimizer + .rewrite(plan.clone(), &ctx) + .expect("rewrite should succeed") + .data; + let LogicalPlan::Projection(Projection { expr, .. }) = simplified else { + panic!("Expected Projection plan"); + }; + assert_eq!(expr.len(), 1); + let simplified = unwrap_aliases(expr.first().unwrap()); + // Should be a literal timestamp + match simplified { + Expr::Literal(ScalarValue::TimestampNanosecond(Some(ts), _), _) => { + assert_eq!(*ts, time.timestamp_nanos_opt().unwrap()); + } + other => panic!("Expected timestamp literal, got: {other:?}"), + } + } + + // Case 2: Without execution start time, now() should remain as a function call + { + let ctx: OptimizerContext = + OptimizerContext::new().without_query_execution_start_time(); + let optimizer = SimplifyExpressions {}; + let simplified = optimizer + .rewrite(plan, &ctx) + .expect("rewrite should succeed") + .data; + let LogicalPlan::Projection(Projection { expr, .. }) = simplified else { + panic!("Expected Projection plan"); + }; + assert_eq!(expr.len(), 1); + let simplified = unwrap_aliases(expr.first().unwrap()); + // Should still be a now() function call + match simplified { + Expr::ScalarFunction(ScalarFunction { func, ..
}) => { + assert_eq!(func.name(), "now"); + } + other => panic!("Expected now() function call, got: {other:?}"), + } + } +} + // ------------------------------ // --- Simplifier tests ----- // ------------------------------ @@ -545,11 +567,10 @@ fn expr_test_schema() -> DFSchemaRef { } fn test_simplify(input_expr: Expr, expected_expr: Expr) { - let info: MyInfo = MyInfo { - schema: expr_test_schema(), - execution_props: ExecutionProps::new(), - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::builder() + .with_schema(expr_test_schema()) + .build(); + let simplifier = ExprSimplifier::new(context); let simplified_expr = simplifier .simplify(input_expr.clone()) .expect("successfully evaluated"); @@ -564,11 +585,11 @@ fn test_simplify_with_cycle_count( expected_expr: Expr, expected_count: u32, ) { - let info: MyInfo = MyInfo { - schema: expr_test_schema(), - execution_props: ExecutionProps::new(), - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::builder() + .with_schema(expr_test_schema()) + .with_query_execution_start_time(Some(Utc::now())) + .build(); + let simplifier = ExprSimplifier::new(context); let (simplified_expr, count) = simplifier .simplify_with_cycle_count_transformed(input_expr.clone()) .expect("successfully evaluated"); diff --git a/datafusion/core/tests/extension_types/mod.rs b/datafusion/core/tests/extension_types/mod.rs new file mode 100644 index 0000000000000..bfe0c2e34927e --- /dev/null +++ b/datafusion/core/tests/extension_types/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod pretty_printing; diff --git a/datafusion/core/tests/extension_types/pretty_printing.rs b/datafusion/core/tests/extension_types/pretty_printing.rs new file mode 100644 index 0000000000000..c0796887b8b6e --- /dev/null +++ b/datafusion/core/tests/extension_types/pretty_printing.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
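For background on the new test file that follows: a canonical Arrow extension type such as Uuid rides on top of a storage type (here FixedSizeBinary(16)) via field metadata, and the registry wired into the session below is what lets DataFusion pick a UUID-aware formatter. A minimal sketch of what `with_extension_type` records (assuming the arrow-schema crate with its canonical extension types enabled):

use arrow_schema::extension::Uuid;
use arrow_schema::{DataType, Field};

fn main() {
    let field = Field::new("my_uuids", DataType::FixedSizeBinary(16), false)
        .with_extension_type(Uuid);
    // The canonical extension type travels as field metadata, i.e.
    // "ARROW:extension:name" -> "arrow.uuid"
    println!("{:?}", field.metadata());
}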
+ +use arrow::array::{FixedSizeBinaryArray, RecordBatch}; +use arrow_schema::extension::Uuid; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::dataframe::DataFrame; +use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; +use datafusion::prelude::SessionContext; +use datafusion_expr::registry::MemoryExtensionTypeRegistry; +use insta::assert_snapshot; +use std::sync::Arc; + +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![uuid_field()])) +} + +fn uuid_field() -> Field { + Field::new("my_uuids", DataType::FixedSizeBinary(16), false).with_extension_type(Uuid) +} + +async fn create_test_table() -> Result { + let schema = test_schema(); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(FixedSizeBinaryArray::from(vec![ + &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 5, 6], + ]))], + )?; + + let state = SessionStateBuilder::default() + .with_extension_type_registry(Arc::new( + MemoryExtensionTypeRegistry::new_with_canonical_extension_types(), + )) + .build(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_batch("test", batch)?; + + ctx.table("test").await +} + +#[tokio::test] +async fn test_pretty_print_extension_type_formatter() -> Result<()> { + let result = create_test_table().await?.to_string().await?; + + assert_snapshot!( + result, + @r" + +--------------------------------------+ + | my_uuids | + +--------------------------------------+ + | 00000000-0000-0000-0000-000000000000 | + | 00010203-0405-0607-0809-000102030506 | + +--------------------------------------+ + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 141a3f3b75586..3d99cc72fa590 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -22,21 +22,21 @@ mod unix_test { use std::fs::File; use std::path::PathBuf; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use arrow::array::Array; use arrow::csv::ReaderBuilder; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::datasource::TableProvider; + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, test_util::{aggr_test_schema, arrow_test_data}, }; use datafusion_common::instant::Instant; - use datafusion_common::{exec_err, Result}; + use datafusion_common::{Result, exec_err}; use datafusion_expr::SortExpr; use futures::StreamExt; @@ -44,7 +44,7 @@ mod unix_test { use nix::unistd; use tempfile::TempDir; use tokio::io::AsyncWriteExt; - use tokio::task::{spawn_blocking, JoinHandle}; + use tokio::task::{JoinHandle, spawn_blocking}; /// Makes a TableProvider for a fifo file fn fifo_table( @@ -94,7 +94,6 @@ mod unix_test { /// This function creates a writing task for the FIFO file. To verify /// incremental processing, it waits for a signal to continue writing after /// a certain number of lines are written. 
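+ /// The writer runs on a spawned tokio task; the lint expectation for `tokio::spawn` (`expect(clippy::disallowed_methods)`) is now scoped to that spawn call inside the function body rather than applied to the whole function.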
- #[allow(clippy::disallowed_methods)] fn create_writing_task( file_path: PathBuf, header: String, @@ -105,6 +104,7 @@ mod unix_test { // Timeout for a long period of BrokenPipe error let broken_pipe_timeout = Duration::from_secs(10); // Spawn a new task to write to the FIFO file + #[expect(clippy::disallowed_methods)] tokio::spawn(async move { let mut file = tokio::fs::OpenOptions::new() .write(true) @@ -357,7 +357,7 @@ mod unix_test { (sink_fifo_path.clone(), sink_fifo_path.display()); // Spawn a new thread to read sink EXTERNAL TABLE. - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/fuzz.rs b/datafusion/core/tests/fuzz.rs index 92646e8b37636..5e94f12b5805d 100644 --- a/datafusion/core/tests/fuzz.rs +++ b/datafusion/core/tests/fuzz.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -/// Run all tests that are found in the `fuzz_cases` directory +/// Run all tests that are found in the `fuzz_cases` directory. +/// Fuzz tests are slow and gated behind the `extended_tests` feature. +/// Run with: cargo test --features extended_tests +#[cfg(feature = "extended_tests")] mod fuzz_cases; #[cfg(test)] diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 4e04da26f70b6..4726e7c4aca5c 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -24,37 +24,37 @@ use crate::fuzz_cases::aggregation_fuzzer::{ }; use arrow::array::{ - types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, - StringArray, + Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, StringArray, + types::Int64Type, }; use arrow::compute::concat_batches; use arrow::datatypes::DataType; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{Field, Schema, SchemaRef}; +use datafusion::datasource::MemTable; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::MemTable; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::{HashMap, Result}; use datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; -use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr::expressions::{Column, col, lit}; use datafusion_physical_plan::InputOrderMode; -use test_utils::{add_empty_batches, StringBatchGenerator}; +use test_utils::{StringBatchGenerator, add_empty_batches}; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_execution::TaskContext; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::metrics::MetricValue; -use datafusion_physical_plan::{collect, displayable, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, collect, displayable}; use rand::rngs::StdRng; -use rand::{random, rng, Rng, 
SeedableRng}; +use rand::{Rng, SeedableRng, random, rng}; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] @@ -326,15 +326,14 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str .unwrap(), ); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) - .schema(Arc::clone(&schema)) - .alias("sum1") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("sum1") + .build() + .map(Arc::new) + .unwrap(), + ]; let expr = group_by_columns .iter() .map(|elem| (col(elem, &schema).unwrap(), (*elem).to_string())) @@ -548,14 +547,14 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { type Node = Arc<dyn ExecutionPlan>; fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> { - if let Some(exec) = node.as_any().downcast_ref::<AggregateExec>() { + if let Some(exec) = node.downcast_ref::<AggregateExec>() { if self.expected_sort { assert!(matches!( exec.input_order_mode(), InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted )); } else { - assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); + assert_eq!(*exec.input_order_mode(), InputOrderMode::Linear); } } Ok(TreeNodeRecursion::Continue) } @@ -650,7 +649,9 @@ pub(crate) fn assert_spill_count_metric( if expect_spill && spill_count == 0 { panic!("Expected spill but SpillCount metric not found or SpillCount was 0."); } else if !expect_spill && spill_count > 0 { - panic!("Expected no spill but found SpillCount metric with value greater than 0."); + panic!( + "Expected no spill but found SpillCount metric with value greater than 0." + ); } spill_count diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index fa8ea0b31c023..fe31098622c58 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -25,7 +25,7 @@ use datafusion_catalog::TableProvider; use datafusion_common::ScalarValue; use datafusion_common::{error::Result, utils::get_available_parallelism}; use datafusion_expr::col; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; @@ -214,7 +214,7 @@ impl GeneratedSessionContextBuilder { /// The generated params for [`SessionContext`] #[derive(Debug)] -#[allow(dead_code)] +#[expect(dead_code)] pub struct SessionContextParams { batch_size: usize, target_partitions: usize, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index aaf2d1b9bad4f..e49cffa89b04e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -18,7 +18,7 @@ use arrow::array::RecordBatch; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalSortExpr, expressions::col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::sorts::sort::sort_batch; use test_utils::stagger_batch; @@ -209,8 +209,8 @@ mod test { sort_keys_set: vec![vec!["b".to_string()]], }; - let mut gen =
DatasetGenerator::new(config); - let datasets = gen.generate().unwrap(); + let mut data_gen = DatasetGenerator::new(config); + let datasets = data_gen.generate().unwrap(); // Should Generate 2 datasets assert_eq!(datasets.len(), 2); diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 1a8ef278cc299..430762b1c28db 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -19,9 +19,9 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{internal_datafusion_err, Result}; +use datafusion_common::{Result, internal_datafusion_err}; use datafusion_common_runtime::JoinSet; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs index 766e2bedd74c2..7bb6177c31010 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -17,7 +17,7 @@ use std::{collections::HashSet, str::FromStr}; -use rand::{rng, seq::SliceRandom, Rng}; +use rand::{Rng, rng, seq::SliceRandom}; /// Random aggregate query builder /// @@ -182,13 +182,13 @@ impl QueryBuilder { /// Add max columns num in group by(default: 3), for example if it is set to 1, /// the generated sql will group by at most 1 column - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self { self.max_group_by_columns = max_group_by_columns; self } - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self { self.min_group_by_columns = min_group_by_columns; self @@ -202,7 +202,7 @@ impl QueryBuilder { } /// Add if also test the no grouping aggregation case(default: true) - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_no_grouping(mut self, no_grouping: bool) -> Self { self.no_grouping = no_grouping; self diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs index 3049631d4b3fe..92adda200d1a5 100644 --- a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::array::{cast::AsArray, Array, OffsetSizeTrait, RecordBatch}; +use arrow::array::{Array, OffsetSizeTrait, RecordBatch, cast::AsArray}; use datafusion::datasource::MemTable; use datafusion_common_runtime::JoinSet; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index 171839b390ffa..a57095066ee12 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -16,19 +16,19 @@ // under the License. 
use crate::fuzz_cases::equivalence::utils::{ - create_random_schema, create_test_params, create_test_schema_2, + TestScalarUDF, create_random_schema, create_test_params, create_test_schema_2, generate_table_for_eq_properties, generate_table_for_orderings, - is_table_same_after_sort, TestScalarUDF, + is_table_same_after_sort, }; use arrow::compute::SortOptions; -use datafusion_common::config::ConfigOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::equivalence::{ convert_to_orderings, convert_to_sort_exprs, }; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; -use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index a72a1558b2e41..2f67e211ce915 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -16,15 +16,15 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - apply_projection, create_random_schema, generate_table_for_eq_properties, - is_table_same_after_sort, TestScalarUDF, + TestScalarUDF, apply_projection, create_random_schema, + generate_table_for_eq_properties, is_table_same_after_sort, }; use arrow::compute::SortOptions; -use datafusion_common::config::ConfigOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 382c4da943219..1490eb08a0291 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -18,13 +18,13 @@ use std::sync::Arc; use crate::fuzz_cases::equivalence::utils::{ - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - TestScalarUDF, + TestScalarUDF, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, }; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr::{LexOrdering, ScalarFunctionExpr}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index be35ddca8f02d..8350cafb215cb 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -15,26 +15,25 @@ // specific language governing permissions and 
limitations // under the License. -use std::any::Any; use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; -use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn, SortOptions}; +use arrow::compute::{SortColumn, SortOptions, lexsort_to_indices, take_record_batch}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; +use datafusion_common::{Result, exec_err, internal_datafusion_err, plan_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_physical_expr::equivalence::{ - convert_to_orderings, EquivalenceClass, ProjectionMapping, + EquivalenceClass, ProjectionMapping, convert_to_orderings, }; use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_plan::expressions::{col, Column}; +use datafusion_physical_plan::expressions::{Column, col}; use itertools::izip; use rand::prelude::*; @@ -50,7 +49,7 @@ pub fn output_schema( let data_type = source.data_type(input_schema)?; let nullable = source.nullable(input_schema)?; for (target, _) in targets.iter() { - let Some(column) = target.as_any().downcast_ref::<Column>() else { + let Some(column) = target.downcast_ref::<Column>() else { return plan_err!("Expects to have column"); }; fields.push(Field::new(column.name(), data_type.clone(), nullable)); } @@ -283,7 +282,7 @@ fn get_representative_arr( schema: SchemaRef, ) -> Option<ArrayRef> { for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::<Column>().unwrap(); + let col = expr.downcast_ref::<Column>().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); if let Some(res) = &existing_vec[idx] { return Some(Arc::clone(res)); } @@ -371,7 +370,7 @@ pub fn generate_table_for_eq_properties( // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr.as_any().downcast_ref::<Column>().unwrap(); + let col = constant.expr.downcast_ref::<Column>().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; @@ -383,7 +382,7 @@ pub fn generate_table_for_eq_properties( let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering .iter() .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::<Column>().unwrap(); + let col = expr.downcast_ref::<Column>().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = generate_random_array(n_elem, n_distinct); ( @@ -409,7 +408,7 @@ pub fn generate_table_for_eq_properties( .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::<Column>().unwrap(); + let col = expr.downcast_ref::<Column>().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); schema_vec[idx] = Some(Arc::clone(&representative_array)); } @@ -531,9 +530,6 @@ impl TestScalarUDF { } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "test-scalar-udf" } diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index
e8ff1ccf06704..fdb2934817bc5 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -38,8 +38,11 @@ use datafusion::physical_plan::joins::{ }; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{NullEquality, ScalarValue}; -use datafusion_physical_expr::expressions::Literal; +use datafusion_execution::TaskContext; +use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr::expressions::Literal; use itertools::Itertools; use rand::Rng; @@ -91,484 +94,564 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { #[tokio::test] async fn test_inner_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Inner, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Inner, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Left, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Left, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Right, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - 
make_staggered_batches_i32(1000), - JoinType::Right, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Full, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Full, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[NljHj, HjSmj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, 
left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } // todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support 
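+// Note: these join-fuzz cases iterate over (left_extra, right_extra) schema combinations, so the two join inputs do not always share the extra 'y' column; `make_staggered_batches_i32` / `make_staggered_batches_binary` append 'y' only when requested.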
#[tokio::test] async fn test_right_mark_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_mark_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Inner, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Inner, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Left, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Left, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Right, - None, - ) - .run_test(&[HjSmj, NljHj], 
false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Right, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Full, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Full, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[NljHj, HjSmj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + 
make_staggered_batches_binary(1000, right_extra), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } 
} #[tokio::test] async fn test_left_mark_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } // todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support #[tokio::test] async fn test_right_mark_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_mark_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } type JoinFilterBuilder = Box, Arc) -> JoinFilter>; @@ -769,6 +852,7 @@ impl JoinFuzzTestCase { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, ) .unwrap(), ) @@ -841,7 +925,9 @@ impl JoinFuzzTestCase { std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); - println!("Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}"); + println!( + "Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}" + ); println!("The debug is ON. 
Input data will be saved to {out_dir_name}"); Self::save_partitioned_batches_as_parquet( @@ -892,10 +978,18 @@ impl JoinFuzzTestCase { } if join_tests.contains(&NljHj) { - let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}"); + let err_msg_rowcnt = format!( + "NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}" + ); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); + if nlj_rows == 0 && hj_rows == 0 { + // both joins returned no rows, skip content comparison + continue; + } - let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}"); + let err_msg_contents = format!( + "NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}" + ); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -913,10 +1007,16 @@ impl JoinFuzzTestCase { } if join_tests.contains(&HjSmj) { - let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); + let err_msg_row_cnt = format!( + "HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", + &batch_size + ); assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); - let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size); + let err_msg_contents = format!( + "SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", + &batch_size + ); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows if smj_rows > 0 || hj_rows > 0 { @@ -990,7 +1090,7 @@ impl JoinFuzzTestCase { /// Files can be of different sizes /// The method can be useful to read partitions have been saved by `save_partitioned_batches_as_parquet` /// for test debugging purposes - #[allow(dead_code)] + #[expect(dead_code)] async fn load_partitioned_batches_from_parquet( dir: &str, ) -> std::io::Result> { @@ -1028,10 +1128,142 @@ impl JoinFuzzTestCase { } } +/// Fuzz test: compare SMJ (with spilling) against HJ (no spill) for filtered +/// outer joins under memory pressure. This exercises the deferred filtering + +/// spill read-back path that unit tests can't easily cover with random data. 
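+/// The SMJ side runs under a 4 KiB memory limit with an OS-tmp-dir disk manager so it is forced to spill, while the HJ baseline runs with no memory limit; row counts and sorted row contents are then compared between the two plans.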
+#[tokio::test] +async fn test_filtered_join_spill_fuzz() { + let join_types = [JoinType::Left, JoinType::Right, JoinType::Full]; + + let runtime_spill = RuntimeEnvBuilder::new() + .with_memory_limit(4096, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc() + .unwrap(); + + for join_type in &join_types { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + let input1 = make_staggered_batches_i32(1000, left_extra); + let input2 = make_staggered_batches_i32(1000, right_extra); + + let schema1 = input1[0].schema(); + let schema2 = input2[0].schema(); + let filter = col_lt_col_filter(schema1.clone(), schema2.clone()); + + let on = vec![ + ( + Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _, + Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _, + ), + ( + Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _, + Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _, + ), + ]; + + for batch_size in [2, 49, 100] { + let session_config = SessionConfig::new().with_batch_size(batch_size); + + // HJ baseline (no memory limit) + let left_hj = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&input1), + schema1.clone(), + None, + ) + .unwrap(); + let right_hj = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&input2), + schema2.clone(), + None, + ) + .unwrap(); + let hj = Arc::new( + HashJoinExec::try_new( + left_hj, + right_hj, + on.clone(), + Some(filter.clone()), + join_type, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + let ctx_hj = SessionContext::new_with_config(session_config.clone()); + let hj_collected = collect(hj, ctx_hj.task_ctx()).await.unwrap(); + + // SMJ with spilling + let left_smj = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&input1), + schema1.clone(), + None, + ) + .unwrap(); + let right_smj = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&input2), + schema2.clone(), + None, + ) + .unwrap(); + let smj = Arc::new( + SortMergeJoinExec::try_new( + left_smj, + right_smj, + on.clone(), + Some(filter.clone()), + *join_type, + vec![SortOptions::default(); on.len()], + NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + let task_ctx_spill = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime_spill)), + ); + let smj_collected = collect(smj, task_ctx_spill).await.unwrap(); + + let hj_rows: usize = hj_collected.iter().map(|b| b.num_rows()).sum(); + let smj_rows: usize = smj_collected.iter().map(|b| b.num_rows()).sum(); + + assert_eq!( + hj_rows, smj_rows, + "Row count mismatch for {join_type:?} batch_size={batch_size} \ + left_extra={left_extra} right_extra={right_extra}: \ + HJ={hj_rows} SMJ={smj_rows}" + ); + + if hj_rows > 0 { + let hj_fmt = + pretty_format_batches(&hj_collected).unwrap().to_string(); + let smj_fmt = + pretty_format_batches(&smj_collected).unwrap().to_string(); + + let mut hj_sorted: Vec<&str> = hj_fmt.trim().lines().collect(); + hj_sorted.sort_unstable(); + let mut smj_sorted: Vec<&str> = smj_fmt.trim().lines().collect(); + smj_sorted.sort_unstable(); + + assert_eq!( + hj_sorted, smj_sorted, + "Content mismatch for {join_type:?} batch_size={batch_size} \ + left_extra={left_extra} right_extra={right_extra}" + ); + } + } + } + } +} + /// Return randomly sized record batches with: /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns /// two 
random int32 columns 'x', 'y' as other columns -fn make_staggered_batches_i32(len: usize) -> Vec<RecordBatch> { +fn make_staggered_batches_i32(len: usize, with_extra_column: bool) -> Vec<RecordBatch> { let mut rng = rand::rng(); let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; let mut input3: Vec<i32> = vec![0; len]; @@ -1044,17 +1276,28 @@ fn make_staggered_batches_i32(len: usize) -> Vec<RecordBatch> { input12.sort_unstable(); let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0)); let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1)); - let input3 = Int32Array::from_iter_values(input3); + let input3 = Int32Array::from_iter(input3.into_iter().map(|v| { + // ~10% NULLs in filter column to exercise NULL filter handling + if rng.random_range(0..10) == 0 { + None + } else { + Some(v) + } + })); let input4 = Int32Array::from_iter_values(input4); - // split into several record batches - let batch = RecordBatch::try_from_iter(vec![ + let mut columns = vec![ ("a", Arc::new(input1) as ArrayRef), ("b", Arc::new(input2) as ArrayRef), ("x", Arc::new(input3) as ArrayRef), - ("y", Arc::new(input4) as ArrayRef), - ]) - .unwrap(); + ]; + + if with_extra_column { + columns.push(("y", Arc::new(input4) as ArrayRef)); + } + + // split into several record batches + let batch = RecordBatch::try_from_iter(columns).unwrap(); // use a random number generator to pick a random sized output stagger_batch_with_seed(batch, 42) @@ -1070,7 +1313,10 @@ fn rand_bytes<R: Rng>(rng: &mut R, min: usize, max: usize) -> Vec<u8> { /// Return randomly sized record batches with: /// two sorted binary columns 'a', 'b' (lexicographically) as join columns /// two random binary columns 'x', 'y' as other columns -fn make_staggered_batches_binary(len: usize) -> Vec<RecordBatch> { +fn make_staggered_batches_binary( + len: usize, + with_extra_column: bool, +) -> Vec<RecordBatch> { let mut rng = rand::rng(); // produce (a,b) pairs then sort lexicographically so SMJ has naturally sorted keys @@ -1088,13 +1334,17 @@ fn make_staggered_batches_binary(len: usize) -> Vec<RecordBatch> { let x = BinaryArray::from_iter_values(input3.iter()); let y = BinaryArray::from_iter_values(input4.iter()); - let batch = RecordBatch::try_from_iter(vec![ + let mut columns = vec![ ("a", Arc::new(a) as ArrayRef), ("b", Arc::new(b) as ArrayRef), ("x", Arc::new(x) as ArrayRef), - ("y", Arc::new(y) as ArrayRef), - ]) - .unwrap(); + ]; + + if with_extra_column { + columns.push(("y", Arc::new(y) as ArrayRef)); + } + + let batch = RecordBatch::try_from_iter(columns).unwrap(); // preserve your existing randomized partitioning stagger_batch_with_seed(batch, 42) diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 4c5ebf0402414..1c5741e7a21b3 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -24,7 +24,7 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use std::sync::Arc; use test_utils::stagger_batch; diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index b92dec64e3f19..59430a98cc4b4 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -27,7 +27,7 @@ use arrow::{ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::physical_plan::{ collect, - expressions::{col,
PhysicalSortExpr}, + expressions::{PhysicalSortExpr, col}, sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 9e2fd170f7f0c..edb53df382c62 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -15,20 +15,26 @@ // specific language governing permissions and limitations // under the License. +#[expect(clippy::needless_pass_by_value)] mod aggregate_fuzz; mod distinct_count_string_fuzz; +#[expect(clippy::needless_pass_by_value)] mod join_fuzz; mod merge_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_query_fuzz; mod topk_filter_pushdown; mod aggregation_fuzzer; +#[expect(clippy::needless_pass_by_value)] mod equivalence; mod pruning; mod limit_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/once_exec.rs b/datafusion/core/tests/fuzz_cases/once_exec.rs index 49e2caaa7417c..403e377a690e2 100644 --- a/datafusion/core/tests/fuzz_cases/once_exec.rs +++ b/datafusion/core/tests/fuzz_cases/once_exec.rs @@ -17,13 +17,13 @@ use arrow_schema::SchemaRef; use datafusion_common::internal_datafusion_err; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; -use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::{Arc, Mutex}; @@ -32,7 +32,7 @@ use std::sync::{Arc, Mutex}; pub struct OnceExec { /// the results to send back stream: Mutex<Option<SendableRecordBatchStream>>, - cache: PlanProperties, + cache: Arc<PlanProperties>, } impl Debug for OnceExec { @@ -46,7 +46,7 @@ impl OnceExec { let cache = Self::compute_properties(stream.schema()); Self { stream: Mutex::new(Some(stream)), - cache, + cache: Arc::new(cache), } } @@ -79,11 +79,7 @@ impl ExecutionPlan for OnceExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc<PlanProperties> { &self.cache } @@ -110,4 +106,20 @@ impl ExecutionPlan for OnceExec { stream.ok_or_else(|| internal_datafusion_err!("Stream already consumed")) } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion_physical_plan::PhysicalExpr, + ) -> datafusion_common::Result<TreeNodeRecursion>, + ) -> datafusion_common::Result<TreeNodeRecursion> { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index f8bd4dbc1a768..8ce5207f91190 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -29,9 +29,11 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_plan::{collect, filter::FilterExec, ExecutionPlan}; +use
datafusion_physical_plan::{ExecutionPlan, collect, filter::FilterExec}; use itertools::Itertools; -use object_store::{memory::InMemory, path::Path, ObjectStore, PutPayload}; +use object_store::{ + ObjectStore, ObjectStoreExt, PutPayload, memory::InMemory, path::Path, +}; use parquet::{ arrow::ArrowWriter, file::properties::{EnabledStatistics, WriterProperties}, @@ -276,13 +278,12 @@ async fn execute_with_predicate( ctx: &SessionContext, ) -> Vec { let parquet_source = if prune_stats { - ParquetSource::default().with_predicate(predicate.clone()) + ParquetSource::new(schema.clone()).with_predicate(predicate.clone()) } else { - ParquetSource::default() + ParquetSource::new(schema.clone()) }; let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("memory://").unwrap(), - schema.clone(), Arc::new(parquet_source), ) .with_file_group( diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs index 45dba5f7864b1..22b145f5095a7 100644 --- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -19,23 +19,23 @@ use std::sync::Arc; use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; use arrow::datatypes::{ - ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, - Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType, + ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal32Type, + Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, + UInt64Type, }; use arrow_schema::{ - DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, - DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, + DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, + DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; -use rand::{rng, rngs::StdRng, Rng, SeedableRng}; +use datafusion_common::{Result, arrow_datafusion_err}; +use rand::{Rng, SeedableRng, rng, rngs::StdRng}; use test_utils::array_gen::{ BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, PrimitiveArrayGenerator, StringArrayGenerator, diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 28d28a6622a76..0d8a066d432dd 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::{ - array::{as_string_array, ArrayRef, Int32Array, StringArray}, + array::{ArrayRef, Int32Array, StringArray, as_string_array}, compute::SortOptions, record_batch::RecordBatch, }; @@ -28,7 +28,7 @@ use 
datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::cast::as_int32_array; use datafusion_execution::memory_pool::GreedyMemoryPool; diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 99b20790fc46b..a1f38f161d6ea 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -20,34 +20,33 @@ mod sp_repartition_fuzz_tests { use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; - use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; + use arrow::compute::{SortColumn, SortOptions, concat_batches, lexsort}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::{ - collect, + ExecutionPlan, Partitioning, collect, metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, repartition::RepartitionExec, sorts::sort_preserving_merge::SortPreservingMergeExec, sorts::streaming_merge::StreamingMergeBuilder, stream::RecordBatchStreamAdapter, - ExecutionPlan, Partitioning, }; use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::{config::SessionConfig, memory_pool::MemoryConsumer}; + use datafusion_physical_expr::ConstExpr; use datafusion_physical_expr::equivalence::{ EquivalenceClass, EquivalenceProperties, }; - use datafusion_physical_expr::expressions::{col, Column}; - use datafusion_physical_expr::ConstExpr; + use datafusion_physical_expr::expressions::{Column, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use test_utils::add_empty_batches; use itertools::izip; - use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom}; // Generate a schema which consists of 6 columns (a, b, c, d, e, f) fn create_test_schema() -> Result { @@ -119,7 +118,7 @@ mod sp_repartition_fuzz_tests { schema: SchemaRef, ) -> Option { for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); + let col = expr.downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); if let Some(res) = &existing_vec[idx] { return Some(res.clone()); @@ -150,7 +149,7 @@ mod sp_repartition_fuzz_tests { // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr.as_any().downcast_ref::().unwrap(); + let col = constant.expr.downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; @@ -162,7 +161,7 @@ mod sp_repartition_fuzz_tests { let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering .iter() .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); + let col = expr.downcast_ref::().unwrap(); let (idx, _field) = 
schema.column_with_name(col.name()).unwrap(); let arr = generate_random_array(n_elem, n_distinct); ( @@ -188,7 +187,7 @@ mod sp_repartition_fuzz_tests { .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); + let col = expr.downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); schema_vec[idx] = Some(representative_array.clone()); } @@ -302,7 +301,7 @@ mod sp_repartition_fuzz_tests { let mut handles = Vec::new(); for seed in seed_start..seed_end { - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests let job = tokio::spawn(run_sort_preserving_repartition_test( make_staggered_batches::(n_row, n_distinct, seed as u64), is_first_roundrobin, diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index 2ce7db3ea4bc7..376306f3e0659 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -24,24 +24,22 @@ use arrow::array::RecordBatch; use arrow_schema::SchemaRef; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::{instant::Instant, Result}; +use datafusion_common::{Result, human_readable_size, instant::Instant}; use datafusion_execution::disk_manager::DiskManagerBuilder; -use datafusion_execution::memory_pool::{ - human_readable_size, MemoryPool, UnboundedMemoryPool, -}; +use datafusion_execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; use std::time::Duration; use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; -use rand::prelude::IndexedRandom; use rand::Rng; -use rand::{rngs::StdRng, SeedableRng}; +use rand::prelude::IndexedRandom; +use rand::{SeedableRng, rngs::StdRng}; use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; use super::aggregation_fuzzer::ColumnDescr; -use super::record_batch_generator::{get_supported_types_columns, RecordBatchGenerator}; +use super::record_batch_generator::{RecordBatchGenerator, get_supported_types_columns}; /// Entry point for executing the sort query fuzzer. 
/// @@ -177,16 +175,16 @@ impl SortQueryFuzzer { n_round: usize, n_query: usize, ) -> bool { - if let Some(time_limit) = self.time_limit { - if Instant::now().duration_since(start_time) > time_limit { - println!( - "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", - n_round * self.queries_per_round + n_query, - self.config_variations_per_query, - n_round - ); - return true; - } + if let Some(time_limit) = self.time_limit + && Instant::now().duration_since(start_time) > time_limit + { + println!( + "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", + n_round * self.queries_per_round + n_query, + self.config_variations_per_query, + n_round + ); + return true; } false } diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs index 6c1bd316cdd39..d401557e966d6 100644 --- a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -27,18 +27,18 @@ use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch} use arrow_schema::{DataType, Field, Schema}; use datafusion::common::Result; use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; -use datafusion_execution::memory_pool::units::{KB, MB}; +use datafusion_common::units::{KB, MB}; use datafusion_execution::memory_pool::{ FairSpillPool, MemoryConsumer, MemoryReservation, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::expressions::{col, Column}; +use datafusion_physical_expr::expressions::{Column, col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -80,9 +80,9 @@ async fn test_sort_with_limited_memory() -> Result<()> { let total_spill_files_size = spill_count * record_batch_size; assert!( - total_spill_files_size > pool_size, - "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", - ); + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); Ok(()) } @@ -126,8 +126,8 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> } #[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( -) -> Result<()> { +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -164,8 +164,8 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_c } #[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( -) -> Result<()> { +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory() +-> Result<()> { let record_batch_size = 8192; let pool_size = 
2 * MB as usize; let task_ctx = { @@ -278,9 +278,11 @@ async fn run_sort_test_with_limited_memory( let string_item_size = record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); + let string_array = + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "a".repeat(string_item_size), + record_batch_size as usize, + ))); RecordBatch::try_new( Arc::clone(&schema), @@ -356,16 +358,16 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<() let total_spill_files_size = spill_count * record_batch_size; assert!( - total_spill_files_size > pool_size, - "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", - ); + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); Ok(()) } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -398,8 +400,8 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -432,8 +434,8 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -466,8 +468,8 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -536,9 +538,11 @@ async fn run_test_aggregate_with_high_cardinality( let string_item_size = record_batch_memory_size / record_batch_size as usize; - let string_array = Arc::new(StringArray::from_iter_values( - (0..record_batch_size).map(|_| "a".repeat(string_item_size)), - )); + let string_array = + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "a".repeat(string_item_size), + record_batch_size as usize, + ))); RecordBatch::try_new( Arc::clone(&schema), diff --git a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs index 7f994daeaa58c..d14afaf1b3267 100644 --- a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs +++ b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs @@ -31,7 +31,7 @@ use 
datafusion_execution::object_store::ObjectStoreUrl; use itertools::Itertools; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; use parquet::arrow::ArrowWriter; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 65a41d39d3c54..82b6d0e4e9d89 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -18,24 +18,24 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, StringArray}; -use arrow::compute::{concat_batches, SortOptions}; +use arrow::compute::{SortOptions, concat_batches}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec, + BoundedWindowAggExec, WindowAggExec, create_window_expr, schema_add_window_field, }; -use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; -use datafusion::physical_plan::{collect, InputOrderMode}; +use datafusion::physical_plan::{InputOrderMode, collect}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf; +use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -445,14 +445,14 @@ fn get_random_function( let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); - if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { - if !args.is_empty() { - // Do type coercion first argument - let a = args[0].clone(); - let dt = a.return_field(schema.as_ref()).unwrap(); - let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); - } + if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn + && !args.is_empty() + { + // Do type coercion first argument + let a = args[0].clone(); + let dt = a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_udf(&[dt], udf.as_ref()).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } (window_fn.clone(), args, (*fn_name).to_string()) @@ -569,10 +569,11 @@ fn convert_bound_to_current_row_if_applicable( ) { match bound { WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { - if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { - if value == &zero && rng.random_range(0..2) == 0 { - *bound = WindowFrameBound::CurrentRow; - } + if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) + && value == &zero + && rng.random_range(0..2) == 0 + { + *bound = WindowFrameBound::CurrentRow; } } _ => {} @@ -588,7 +589,7 @@ async fn run_window_test( 
orderby_columns: Vec<&str>,
     search_mode: InputOrderMode,
 ) -> Result<()> {
-    let is_linear = !matches!(search_mode, Sorted);
+    let is_linear = search_mode != Sorted;
     let mut rng = StdRng::seed_from_u64(random_seed);
     let schema = input1[0].schema();
     let session_config = SessionConfig::new().with_batch_size(50);
@@ -644,10 +645,8 @@ async fn run_window_test(
     ) as _;
     // Table is ordered according to ORDER BY a, b, c. In the linear test we use PARTITION BY b, ORDER BY a.
     // For WindowAggExec to produce a correct result it needs the table to be ordered by b, a. Hence add a sort.
-    if is_linear {
-        if let Some(ordering) = LexOrdering::new(sort_keys) {
-            exec1 = Arc::new(SortExec::new(ordering, exec1)) as _;
-        }
+    if is_linear && let Some(ordering) = LexOrdering::new(sort_keys) {
+        exec1 = Arc::new(SortExec::new(ordering, exec1)) as _;
     }
     let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?;
@@ -699,7 +698,9 @@ async fn run_window_test(
     // BoundedWindowAggExec should produce more chunks than the usual WindowAggExec.
     // Otherwise it means that we cannot generate results in running mode.
-    let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}");
+    let err_msg = format!(
+        "Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}"
+    );
     // The check below makes sure that streaming execution generates more chunks than the bulk
     // execution, since algorithms and operators work on sliding windows in the streaming execution.
     // However, in the current test setup for some randomly generated window frame clauses it is not guaranteed
@@ -731,8 +732,12 @@ async fn run_window_test(
         .enumerate()
     {
         if !usual_line.eq(running_line) {
-            println!("Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}");
-            println!("--------usual_formatted_sorted----------------running_formatted_sorted--------");
+            println!(
+                "Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}"
+            );
+            println!(
+                "--------usual_formatted_sorted----------------running_formatted_sorted--------"
+            );
             for (line1, line2) in
                 usual_formatted_sorted.iter().zip(running_formatted_sorted)
             {
diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs
index c9f33f6fdf0f4..9fd60cd1f06f3 100644
--- a/datafusion/core/tests/macro_hygiene/mod.rs
+++ b/datafusion/core/tests/macro_hygiene/mod.rs
@@ -73,7 +73,7 @@ mod config_field {
     #[test]
     fn test_macro() {
         #[derive(Debug)]
-        #[allow(dead_code)]
+        #[expect(dead_code)]
         struct E;

         impl std::fmt::Display for E {
@@ -84,7 +84,8 @@ mod config_field {

         impl std::error::Error for E {}

-        #[allow(dead_code)]
+        #[expect(dead_code)]
+        #[derive(Default)]
         struct S;

         impl std::str::FromStr for S {
diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs
index e1d5f1b1ab198..bf04123fff7fa 100644
--- a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs
+++ b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs
@@ -21,14 +21,10 @@
 //! This file is organized as:
 //! - Test runners that spawn individual test processes
 //! - Test cases that contain the actual validation logic
-use log::info;
-use std::sync::Once;
 use std::{process::Command, str};

 use crate::memory_limit::memory_limit_validation::utils;

-static INIT: Once = Once::new();
-
 // ===========================================================================
 // Test runners:
 // Runners are split into multiple tests to run in parallel
@@ -69,49 +65,16 @@ fn sort_with_mem_limit_2_cols_2_runner() {
     spawn_test_process("sort_with_mem_limit_2_cols_2");
 }

-/// `spawn_test_process` might trigger multiple recompilations and the test binary
-/// size might grow indefinitely. This initializer ensures recompilation is only done
-/// once and the target size is bounded.
-///
-/// TODO: This is a hack, can be cleaned up if we have a better way to let multiple
-/// test cases run in different processes (instead of different threads by default)
-fn init_once() {
-    INIT.call_once(|| {
-        let _ = Command::new("cargo")
-            .arg("test")
-            .arg("--no-run")
-            .arg("--package")
-            .arg("datafusion")
-            .arg("--test")
-            .arg("core_integration")
-            .arg("--features")
-            .arg("extended_tests")
-            .env("DATAFUSION_TEST_MEM_LIMIT_VALIDATION", "1")
-            .output()
-            .expect("Failed to execute test command");
-    });
-}
-
-/// Helper function that executes a test in a separate process with the required environment
-/// variable set. Memory limit validation tasks need to measure memory resident set
-/// size (RSS), so they must run in a separate process.
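// ===========================================================================
// A minimal sketch of the re-invocation pattern adopted above (illustration
// only; "module::my_mem_test" is a hypothetical test path). Re-running the
// current test binary with libtest's `--exact`/`--nocapture` flags executes a
// single named test in a child process, which is what per-process RSS
// measurement requires.
fn run_in_child_process_sketch(test_path: &str) {
    use std::process::Command;

    let exe = std::env::current_exe().expect("failed to locate test binary");
    let status = Command::new(exe)
        .arg(test_path) // fully qualified name, e.g. "module::my_mem_test"
        .arg("--exact") // match the test name exactly, not as a prefix
        .arg("--nocapture") // stream the child's output for debugging
        .env("DATAFUSION_TEST_MEM_LIMIT_VALIDATION", "1")
        .status()
        .expect("failed to spawn child test process");
    assert!(status.success(), "child test '{test_path}' failed");
}
// ===========================================================================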
+/// Helper function that executes a test in a separate process with the required +/// environment variable set. Re-invokes the current test binary directly, +/// avoiding cargo overhead and recompilation. fn spawn_test_process(test: &str) { - init_once(); - let test_path = format!("memory_limit::memory_limit_validation::sort_mem_validation::{test}"); - info!("Running test: {test_path}"); - - // Run the test command - let output = Command::new("cargo") - .arg("test") - .arg("--package") - .arg("datafusion") - .arg("--test") - .arg("core_integration") - .arg("--features") - .arg("extended_tests") - .arg("--") + + let exe = std::env::current_exe().expect("Failed to get test binary path"); + + let output = Command::new(exe) .arg(&test_path) .arg("--exact") .arg("--nocapture") @@ -119,12 +82,9 @@ fn spawn_test_process(test: &str) { .output() .expect("Failed to execute test command"); - // Convert output to strings let stdout = str::from_utf8(&output.stdout).unwrap_or(""); let stderr = str::from_utf8(&output.stderr).unwrap_or(""); - info!("{stdout}"); - assert!( output.status.success(), "Test '{}' failed with status: {}\nstdout:\n{}\nstderr:\n{}", diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs index 7b157b707a6de..2c9fae20c8606 100644 --- a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs @@ -16,16 +16,14 @@ // under the License. use datafusion_common_runtime::SpawnedTask; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use sysinfo::{ProcessRefreshKind, ProcessesToUpdate, System}; -use tokio::time::{interval, Duration}; +use tokio::time::{Duration, interval}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_execution::{ - memory_pool::{human_readable_size, FairSpillPool}, - runtime_env::RuntimeEnvBuilder, -}; +use datafusion_common::human_readable_size; +use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; /// Measures the maximum RSS (in bytes) during the execution of an async task. RSS /// will be sampled every 7ms. @@ -40,7 +38,7 @@ use datafusion_execution::{ async fn measure_max_rss(f: F) -> (T, usize) where F: FnOnce() -> Fut, - Fut: std::future::Future, + Fut: Future, { // Initialize system information let mut system = System::new_all(); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 5d8a1d24181cb..90459960c5561 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -17,13 +17,13 @@ //! 
This module contains tests for limiting memory at runtime in DataFusion -use std::any::Any; use std::num::NonZeroUsize; use std::sync::{Arc, LazyLock}; #[cfg(feature = "extended_tests")] mod memory_limit_validation; mod repartition_mem_limit; +mod union_nullable_spill; use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray}; use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; @@ -39,19 +39,19 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_catalog::streaming::StreamingTable; use datafusion_catalog::Session; -use datafusion_common::{assert_contains, Result}; +use datafusion_catalog::streaming::StreamingTable; +use datafusion_common::{Result, assert_contains}; +use datafusion_execution::TaskContext; use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_plan::collect as collect_batches; use datafusion_physical_plan::common::collect; use datafusion_physical_plan::spill::get_record_batch_memory_size; @@ -212,6 +212,7 @@ async fn sort_merge_join_spill() { .with_config(config) .with_disk_manager_builder(DiskManagerBuilder::default()) .with_scenario(Scenario::AccessLogStreaming) + .with_expected_success() .run() .await } @@ -602,11 +603,16 @@ async fn test_disk_spill_limit_reached() -> Result<()> { .await .unwrap(); - let err = df.collect().await.unwrap_err(); - assert_contains!( - err.to_string(), - "The used disk space during the spilling process has exceeded the allowable limit" - ); + let error_message = df.collect().await.unwrap_err().to_string(); + for expected in [ + "The used disk space during the spilling process has exceeded the allowable limit", + "datafusion.runtime.max_temp_directory_size", + ] { + assert!( + error_message.contains(expected), + "'{expected}' is not contained by '{error_message}'" + ); + } Ok(()) } @@ -977,11 +983,13 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![[ - PhysicalSortExpr::new(col("a", &schema).unwrap(), options), - PhysicalSortExpr::new(col("b", &schema).unwrap(), options), - ] - .into()]; + let sort_information = vec![ + [ + PhysicalSortExpr::new(col("a", &schema).unwrap(), options), + PhysicalSortExpr::new(col("b", &schema).unwrap(), options), + ] + .into(), + ]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) @@ -1057,7 +1065,7 @@ fn make_dict_batches() -> Vec { let batch_size = 50; let mut i = 0; - let gen = std::iter::from_fn(move || { + let batch_gen = std::iter::from_fn(move || { // create values like // 0000000001 // 0000000002 @@ -1080,7 +1088,7 @@ fn make_dict_batches() -> Vec { let num_batches = 5; - let batches: Vec<_> = gen.take(num_batches).collect(); + let batches: Vec<_> = batch_gen.take(num_batches).collect(); 
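// Note: the iterator binding above is `batch_gen` rather than the previous
// `gen` because `gen` becomes a reserved keyword in the Rust 2024 edition.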
batches.iter().enumerate().for_each(|(i, batch)| { println!("Dict batch[{i}] size is: {}", batch.get_array_memory_size()); @@ -1136,10 +1144,6 @@ impl SortedTableProvider { #[async_trait] impl TableProvider for SortedTableProvider { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } diff --git a/datafusion/core/tests/memory_limit/repartition_mem_limit.rs b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs index a7af2f01d1cc9..27bcd33926e99 100644 --- a/datafusion/core/tests/memory_limit/repartition_mem_limit.rs +++ b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs @@ -25,7 +25,7 @@ use datafusion::{ use datafusion_catalog::MemTable; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_physical_plan::{repartition::RepartitionExec, ExecutionPlanProperties}; +use datafusion_physical_plan::{ExecutionPlanProperties, repartition::RepartitionExec}; use futures::TryStreamExt; use itertools::Itertools; @@ -45,11 +45,14 @@ async fn test_repartition_memory_limit() { .with_batch_size(32) .with_target_partitions(2); let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); - let batches = vec![RecordBatch::try_from_iter(vec![( - "c1", - Arc::new(Int32Array::from_iter_values((0..10).cycle().take(100_000))) as ArrayRef, - )]) - .unwrap()]; + let batches = vec![ + RecordBatch::try_from_iter(vec![( + "c1", + Arc::new(Int32Array::from_iter_values((0..10).cycle().take(100_000))) + as ArrayRef, + )]) + .unwrap(), + ]; let table = Arc::new(MemTable::try_new(batches[0].schema(), vec![batches]).unwrap()); ctx.register_table("t", table).unwrap(); let plan = ctx @@ -71,7 +74,7 @@ async fn test_repartition_memory_limit() { let mut metrics = None; Arc::clone(&plan) .transform_down(|node| { - if node.as_any().is::() { + if node.is::() { metrics = node.metrics(); } Ok(Transformed::no(node)) diff --git a/datafusion/core/tests/memory_limit/union_nullable_spill.rs b/datafusion/core/tests/memory_limit/union_nullable_spill.rs new file mode 100644 index 0000000000000..c5ef2387d3cdc --- /dev/null +++ b/datafusion/core/tests/memory_limit/union_nullable_spill.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
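// Background sketch for the nullability mismatch exercised in this file (a
// standalone illustration of assumed arrow-schema merge semantics, separate
// from the spill test below): when two inputs disagree on nullability, the
// merged field is nullable if either side is, which is why the UNION schema
// below reports a nullable `val`.
#[test]
fn union_nullability_merge_sketch() {
    use arrow_schema::{DataType, Field, Schema};

    let left = Schema::new(vec![Field::new("val", DataType::Int64, false)]);
    let right = Schema::new(vec![Field::new("val", DataType::Int64, true)]);
    // `try_merge` unions fields by name; nullability is OR-ed across inputs.
    let merged = Schema::try_merge([left, right]).unwrap();
    assert!(merged.field(0).is_nullable());
}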
+ +use std::sync::Arc; + +use arrow::array::{Array, Int64Array, RecordBatch}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::memory_pool::FairSpillPool; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::sort_batch; +use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_plan::{ExecutionPlan, Partitioning}; +use futures::StreamExt; + +const NUM_BATCHES: usize = 200; +const ROWS_PER_BATCH: usize = 10; + +fn non_nullable_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, false), + ])) +} + +fn nullable_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, true), + ])) +} + +fn non_nullable_batches() -> Vec { + (0..NUM_BATCHES) + .map(|i| { + let start = (i * ROWS_PER_BATCH) as i64; + let keys: Vec = (start..start + ROWS_PER_BATCH as i64).collect(); + RecordBatch::try_new( + non_nullable_schema(), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])), + ], + ) + .unwrap() + }) + .collect() +} + +fn nullable_batches() -> Vec { + (0..NUM_BATCHES) + .map(|i| { + let start = (i * ROWS_PER_BATCH) as i64; + let keys: Vec = (start..start + ROWS_PER_BATCH as i64).collect(); + let vals: Vec> = (0..ROWS_PER_BATCH) + .map(|j| if j % 3 == 1 { None } else { Some(j as i64) }) + .collect(); + RecordBatch::try_new( + nullable_schema(), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(vals)), + ], + ) + .unwrap() + }) + .collect() +} + +fn build_task_ctx(pool_size: usize) -> Arc { + let session_config = SessionConfig::new().with_batch_size(2); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(pool_size))) + .build_arc() + .unwrap(); + Arc::new( + datafusion_execution::TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ) +} + +/// Exercises spilling through UnionExec -> RepartitionExec where union children +/// have mismatched nullability (one child's `val` is non-nullable, the other's +/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill. +/// +/// UnionExec returns child streams without schema coercion, so batches from +/// different children carry different per-field nullability into the shared +/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable) +/// schema — not the first batch's schema — so readback batches are valid. 
+/// +/// Otherwise, sort_batch will panic with +/// `Column 'val' is declared as non-nullable but contains null values` +#[tokio::test] +async fn test_sort_union_repartition_spill_mixed_nullability() { + let non_nullable_exec = MemorySourceConfig::try_new_exec( + &[non_nullable_batches()], + non_nullable_schema(), + None, + ) + .unwrap(); + + let nullable_exec = + MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None) + .unwrap(); + + let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap(); + assert!(union_exec.schema().field(1).is_nullable()); + + let repartition = Arc::new( + RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + + let task_ctx = build_task_ctx(200); + let mut stream = repartition.execute(0, task_ctx).unwrap(); + + let sort_expr = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("key", &nullable_schema()).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + + let mut total_rows = 0usize; + let mut total_nulls = 0usize; + while let Some(result) = stream.next().await { + let batch = result.unwrap(); + + let batch = sort_batch(&batch, &sort_expr, None).unwrap(); + + total_rows += batch.num_rows(); + total_nulls += batch.column(1).null_count(); + } + + assert_eq!( + total_rows, + NUM_BATCHES * ROWS_PER_BATCH * 2, + "All rows from both UNION branches should be present" + ); + assert!( + total_nulls > 0, + "Expected some null values in output (i.e. nullable batches were processed)" + ); +} diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d0..a461c6f6c5962 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -19,7 +19,6 @@ //! datafusion-functions crate. 
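// A compact sketch of the guarantee-rewrite flow that the tests below rely
// on. The interval bounds here are illustrative; the real guarantees are
// defined inside `test_inequalities_non_null_bounded`.
#[test]
fn rewrite_with_guarantees_sketch() {
    use datafusion_common::tree_node::TransformedResult;
    use datafusion_expr::expr_rewriter::rewrite_with_guarantees;
    use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
    use datafusion_expr::{col, lit};

    // Guarantee: `x` is non-null and lies in [1, 3].
    let guarantees = vec![(
        col("x"),
        NullableInterval::NotNull {
            values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(),
        },
    )];
    // Under that guarantee `x < 0` can never be true, so it folds to `false`.
    let simplified = rewrite_with_guarantees(col("x").lt(lit(0)), &guarantees)
        .data()
        .unwrap();
    assert_eq!(simplified, lit(false));
}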
use insta::assert_snapshot; -use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -27,17 +26,16 @@ use arrow::datatypes::{ DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; +use datafusion_common::tree_node::TransformedResult; +use datafusion_common::{DFSchema, Result, ScalarValue, TableReference, plan_err}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - col, lit, AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, TableSource, WindowUDF, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, HigherOrderUDF, LogicalPlan, Operator, + ScalarUDF, TableSource, WindowUDF, col, lit, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -45,6 +43,7 @@ use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use chrono::DateTime; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees; use datafusion_functions::datetime; #[cfg(test)] @@ -217,6 +216,10 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } + fn get_higher_order_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } @@ -237,6 +240,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn higher_order_function_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } @@ -251,10 +258,6 @@ struct MyTableSource { } impl TableSource for MyTableSource { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -304,8 +307,6 @@ fn test_inequalities_non_null_bounded() { ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // (original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt(lit(0)), false), @@ -337,7 +338,7 @@ fn test_inequalities_non_null_bounded() { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").gt(lit(2)), @@ -348,16 +349,20 @@ fn test_inequalities_non_null_bounded() { col("x").not_between(lit(3), lit(10)), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } -fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) -where +fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], +) where ScalarValue: From, T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -365,9 +370,11 @@ where ); } } -fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { +fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { 
for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); assert_eq!( &output, expr, "{expr} was simplified to {output}, but expected it to be unchanged" diff --git a/datafusion/core/tests/parquet/content_defined_chunking.rs b/datafusion/core/tests/parquet/content_defined_chunking.rs new file mode 100644 index 0000000000000..6a98ded1bd4cf --- /dev/null +++ b/datafusion/core/tests/parquet/content_defined_chunking.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for parquet content-defined chunking (CDC). +//! +//! These tests verify that CDC options are correctly wired through to the +//! parquet writer by inspecting file metadata (compressed sizes, page +//! boundaries) on the written files. + +use arrow::array::{AsArray, Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_common::config::{CdcOptions, TableParquetOptions}; +use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::ArrowReaderMetadata; +use parquet::file::properties::WriterProperties; +use std::fs::File; +use std::sync::Arc; +use tempfile::NamedTempFile; + +/// Create a RecordBatch with enough data to exercise CDC chunking. +fn make_test_batch(num_rows: usize) -> RecordBatch { + let ids: Vec = (0..num_rows as i32).collect(); + // ~100 bytes per row to generate enough data for CDC page splits + let payloads: Vec = (0..num_rows) + .map(|i| format!("row-{i:06}-payload-{}", "x".repeat(80))) + .collect(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("payload", DataType::Utf8, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(ids)), + Arc::new(StringArray::from(payloads)), + ], + ) + .unwrap() +} + +/// Build WriterProperties from TableParquetOptions, exercising the same +/// code path that DataFusion's parquet sink uses. +fn writer_props( + opts: &mut TableParquetOptions, + schema: &Arc, +) -> WriterProperties { + opts.arrow_schema(schema); + parquet::file::properties::WriterPropertiesBuilder::try_from( + opts as &TableParquetOptions, + ) + .unwrap() + .build() +} + +/// Write a batch to a temp parquet file and return the file handle. 
+fn write_parquet_file(batch: &RecordBatch, props: WriterProperties) -> NamedTempFile { + let tmp = tempfile::Builder::new() + .suffix(".parquet") + .tempfile() + .unwrap(); + let mut writer = + ArrowWriter::try_new(tmp.reopen().unwrap(), batch.schema(), Some(props)).unwrap(); + writer.write(batch).unwrap(); + writer.close().unwrap(); + tmp +} + +/// Read parquet metadata from a file. +fn read_metadata(file: &NamedTempFile) -> parquet::file::metadata::ParquetMetaData { + let f = File::open(file.path()).unwrap(); + let reader_meta = ArrowReaderMetadata::load(&f, Default::default()).unwrap(); + reader_meta.metadata().as_ref().clone() +} + +/// Write parquet with CDC enabled, read it back via DataFusion, and verify +/// the data round-trips correctly. +#[tokio::test] +async fn cdc_data_round_trip() { + let batch = make_test_batch(5000); + + let mut opts = TableParquetOptions::default(); + opts.global.use_content_defined_chunking = Some(CdcOptions::default()); + let props = writer_props(&mut opts, &batch.schema()); + + let tmp = write_parquet_file(&batch, props); + + // Read back via DataFusion and verify row count + let ctx = SessionContext::new(); + ctx.register_parquet( + "data", + tmp.path().to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await + .unwrap(); + + let result = ctx + .sql("SELECT COUNT(*), MIN(id), MAX(id) FROM data") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let row = &result[0]; + let count = row.column(0).as_primitive::().value(0); + let min_id = row.column(1).as_primitive::().value(0); + let max_id = row.column(2).as_primitive::().value(0); + + assert_eq!(count, 5000); + assert_eq!(min_id, 0); + assert_eq!(max_id, 4999); +} + +/// Verify that CDC options are reflected in the parquet file metadata. +/// With small chunk sizes, CDC should produce different page boundaries +/// compared to default (no CDC) writing. +#[tokio::test] +async fn cdc_affects_page_boundaries() { + let batch = make_test_batch(5000); + + // Write WITHOUT CDC + let mut no_cdc_opts = TableParquetOptions::default(); + let no_cdc_file = + write_parquet_file(&batch, writer_props(&mut no_cdc_opts, &batch.schema())); + let no_cdc_meta = read_metadata(&no_cdc_file); + + // Write WITH CDC using small chunk sizes to maximize effect + let mut cdc_opts = TableParquetOptions::default(); + cdc_opts.global.use_content_defined_chunking = Some(CdcOptions { + min_chunk_size: 512, + max_chunk_size: 2048, + norm_level: 0, + }); + let cdc_file = + write_parquet_file(&batch, writer_props(&mut cdc_opts, &batch.schema())); + let cdc_meta = read_metadata(&cdc_file); + + // Both files should have the same number of rows + assert_eq!( + no_cdc_meta.file_metadata().num_rows(), + cdc_meta.file_metadata().num_rows(), + ); + + // Compare the uncompressed sizes of columns across all row groups. + // CDC with small chunk sizes should produce different page boundaries. + let no_cdc_sizes: Vec = no_cdc_meta + .row_groups() + .iter() + .flat_map(|rg| rg.columns().iter().map(|c| c.uncompressed_size())) + .collect(); + + let cdc_sizes: Vec = cdc_meta + .row_groups() + .iter() + .flat_map(|rg| rg.columns().iter().map(|c| c.uncompressed_size())) + .collect(); + + assert_ne!( + no_cdc_sizes, cdc_sizes, + "CDC with small chunk sizes should produce different page layouts \ + than default writing. 
no_cdc={no_cdc_sizes:?}, cdc={cdc_sizes:?}" + ); +} diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 3a1f06656236c..ae11fa9a11334 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -20,7 +20,7 @@ use std::ops::Range; use std::sync::Arc; use std::time::SystemTime; -use arrow::array::{ArrayRef, Int64Array, Int8Array, StringArray}; +use arrow::array::{ArrayRef, Int8Array, Int64Array, StringArray}; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use datafusion::datasource::listing::PartitionedFile; @@ -31,8 +31,8 @@ use datafusion::datasource::physical_plan::{ use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::prelude::SessionContext; -use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::Result; +use datafusion_common::test_util::batches_to_sort_string; use bytes::Bytes; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; @@ -43,10 +43,10 @@ use futures::{FutureExt, TryFutureExt}; use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; +use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::AsyncFileReader; -use parquet::arrow::ArrowWriter; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; @@ -69,18 +69,14 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { store_parquet_in_memory(vec![batch]).await; let file_group = parquet_files_meta .into_iter() - .map(|meta| PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: Some(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))), - metadata_size_hint: None, + .map(|meta| { + PartitionedFile::new_from_meta(meta) + .with_extensions(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))) }) .collect(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(file_schema.clone()) // prepare the scan .with_parquet_file_reader_factory(Arc::new( InMemoryParquetFileReaderFactory(Arc::clone(&in_memory_object_store)), @@ -89,7 +85,6 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { let base_config = FileScanConfigBuilder::new( // just any url that doesn't point to in memory object store ObjectStoreUrl::local_filesystem(), - file_schema, source, ) .with_file_group(file_group) diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs index 09b93f06ce85d..12bdb600c2ac9 100644 --- a/datafusion/core/tests/parquet/encryption.rs +++ b/datafusion/core/tests/parquet/encryption.rs @@ -25,11 +25,11 @@ use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::listing::ListingOptions; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_common::config::{EncryptionFactoryOptions, TableParquetOptions}; -use datafusion_common::{assert_batches_sorted_eq, exec_datafusion_err, DataFusionError}; +use datafusion_common::{DataFusionError, assert_batches_sorted_eq, exec_datafusion_err}; use datafusion_datasource_parquet::ParquetFormat; use datafusion_execution::parquet_encryption::EncryptionFactory; -use parquet::arrow::arrow_reader::{ArrowReaderMetadata, 
ArrowReaderOptions}; use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::encryption::decrypt::FileDecryptionProperties; use parquet::encryption::encrypt::FileEncryptionProperties; use parquet::file::column_crypto_metadata::ColumnCryptoMetaData; @@ -54,6 +54,7 @@ async fn read_parquet_test_data<'a, T: Into>( .unwrap() } +#[expect(clippy::needless_pass_by_value)] pub fn write_batches( path: PathBuf, props: WriterProperties, @@ -114,8 +115,8 @@ async fn round_trip_encryption() { // Read encrypted parquet let ctx: SessionContext = SessionContext::new(); - let options = - ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + let options = ParquetReadOptions::default() + .file_decryption_properties((&decrypt).try_into().unwrap()); let encrypted_batches = read_parquet_test_data( tempfile.into_os_string().into_string().unwrap(), diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs new file mode 100644 index 0000000000000..fd70d74a9140c --- /dev/null +++ b/datafusion/core/tests/parquet/expr_adapter.rs @@ -0,0 +1,1128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
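// Quick standalone reference (illustrative values only) for the
// `OffsetBuffer::from_lengths` construction used by `NestedListKind::array`
// below: per-row lengths such as [2, 1] become offsets [0, 2, 3] into the
// flat child values.
#[test]
fn offset_buffer_from_lengths_sketch() {
    use std::sync::Arc;

    use arrow::array::{Array, Int32Array, ListArray};
    use arrow::buffer::OffsetBuffer;
    use arrow_schema::{DataType, Field};

    let values = Int32Array::from(vec![1, 2, 3]);
    let item = Arc::new(Field::new("item", DataType::Int32, true));
    let list = ListArray::new(
        item,
        OffsetBuffer::<i32>::from_lengths([2, 1]), // rows of length 2 and 1
        Arc::new(values),
        None, // no top-level nulls
    );
    assert_eq!(list.len(), 2);
    assert_eq!(list.value(0).len(), 2);
    assert_eq!(list.value(1).len(), 1);
}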
+ +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, BooleanArray, Int32Array, Int64Array, LargeListArray, ListArray, + RecordBatch, StringArray, StructArray, record_batch, +}; +use arrow::buffer::OffsetBuffer; +use arrow::compute::concat_batches; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use bytes::{BufMut, BytesMut}; +use datafusion::assert_batches_eq; +use datafusion::common::Result; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, +}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_datasource::ListingTableUrl; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{self, Column}; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, + PhysicalExprAdapterFactory, +}; +use object_store::{ObjectStore, ObjectStoreExt, memory::InMemory, path::Path}; +use parquet::arrow::ArrowWriter; + +async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store.put(&Path::from(path), data.into()).await.unwrap(); +} + +#[derive(Debug, Clone, Copy)] +enum NestedListKind { + List, + LargeList, +} + +impl NestedListKind { + fn field_data_type(self, item_field: Arc) -> DataType { + match self { + Self::List => DataType::List(item_field), + Self::LargeList => DataType::LargeList(item_field), + } + } + + fn array( + self, + item_field: Arc, + lengths: Vec, + values: ArrayRef, + ) -> ArrayRef { + match self { + Self::List => Arc::new(ListArray::new( + item_field, + OffsetBuffer::::from_lengths(lengths), + values, + None, + )), + Self::LargeList => Arc::new(LargeListArray::new( + item_field, + OffsetBuffer::::from_lengths(lengths), + values, + None, + )), + } + } + + fn name(self) -> &'static str { + match self { + Self::List => "list", + Self::LargeList => "large_list", + } + } +} + +#[derive(Debug)] +// Fixture row for one nested struct element inside the `messages` list column. +struct NestedMessageRow<'a> { + id: i32, + name: &'a str, + chain: Option<&'a str>, + ignored: Option, +} + +fn message_fields( + chain_type: DataType, + chain_nullable: bool, + include_chain: bool, + include_ignored: bool, +) -> Fields { + let mut fields = vec![ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, true)), + ]; + if include_chain { + fields.push(Arc::new(Field::new("chain", chain_type, chain_nullable))); + } + if include_ignored { + fields.push(Arc::new(Field::new("ignored", DataType::Int32, true))); + } + fields.into() +} + +// Helper to construct the target message schema for struct evolution tests. +// The schema always has id (Int64), name (Utf8), and chain with parameterized type. 
+fn target_message_fields(chain_type: DataType, chain_nullable: bool) -> Fields { + vec![ + Arc::new(Field::new("id", DataType::Int64, false)), + Arc::new(Field::new("name", DataType::Utf8, true)), + Arc::new(Field::new("chain", chain_type, chain_nullable)), + ] + .into() +} + +// Helper to build message columns in canonical order (id, name, chain, ignored) +// based on which optional fields are present in the schema. +fn build_message_columns( + id_array: &ArrayRef, + name_array: &ArrayRef, + chain_vec: &[Option<&str>], + ignored_array: &ArrayRef, + fields: &Fields, +) -> Vec<ArrayRef> { + let mut columns = vec![Arc::clone(id_array), Arc::clone(name_array)]; + + for field in fields.iter().skip(2) { + match field.name().as_str() { + "chain" => { + let chain_array = match field.data_type() { + DataType::Utf8 => { + Arc::new(StringArray::from(chain_vec.to_vec())) as ArrayRef + } + DataType::Struct(chain_fields) => { + let chain_struct = StructArray::new( + chain_fields.clone(), + vec![Arc::new(StringArray::from(chain_vec.to_vec())) + as ArrayRef], + None, + ); + Arc::new(chain_struct) as ArrayRef + } + other => panic!("unexpected chain field type: {other:?}"), + }; + columns.push(chain_array); + } + "ignored" => columns.push(Arc::clone(ignored_array)), + _ => {} + } + } + columns +} + +fn nested_messages_batch( + kind: NestedListKind, + row_id: i32, + messages: &[NestedMessageRow<'_>], + fields: &Fields, +) -> RecordBatch { + let item_field = Arc::new(Field::new("item", DataType::Struct(fields.clone()), true)); + + let (ids_vec, names_vec, chain_vec, ignored_vec) = messages.iter().fold( + ( + Vec::with_capacity(messages.len()), + Vec::with_capacity(messages.len()), + Vec::with_capacity(messages.len()), + Vec::with_capacity(messages.len()), + ), + |(mut ids, mut names, mut chains, mut ignoreds), msg| { + ids.push(msg.id); + names.push(Some(msg.name)); + chains.push(msg.chain); + ignoreds.push(msg.ignored); + (ids, names, chains, ignoreds) + }, + ); + + // Build all arrays once + let id_array = Arc::new(Int32Array::from(ids_vec)) as ArrayRef; + let name_array = Arc::new(StringArray::from(names_vec)) as ArrayRef; + let ignored_array = Arc::new(Int32Array::from(ignored_vec)) as ArrayRef; + + // Build columns in canonical order (id, name, chain, ignored) based on field schema + let columns = + build_message_columns(&id_array, &name_array, &chain_vec, &ignored_array, fields); + + let struct_array = StructArray::new(fields.clone(), columns, None); + + // Compute the message data type first, then move item_field into kind.array() + let message_data_type = kind.field_data_type(item_field.clone()); + let messages_array = + kind.array(item_field, vec![messages.len()], Arc::new(struct_array)); + let schema = Arc::new(Schema::new(vec![ + Field::new("row_id", DataType::Int32, false), + Field::new("messages", message_data_type, true), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![row_id])) as ArrayRef, + messages_array, + ], + ) + .unwrap() +} + +async fn register_memory_listing_table( + ctx: &SessionContext, + store: Arc<dyn ObjectStore>, + base_path: &str, + table_schema: SchemaRef, +) { + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse(base_path).unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table =
ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); +} + +fn test_context() -> SessionContext { + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + SessionContext::new_with_config(cfg) +} + +fn nested_list_table_schema( + kind: NestedListKind, + target_message_fields: Fields, +) -> SchemaRef { + let target_item = Arc::new(Field::new( + "item", + DataType::Struct(target_message_fields), + true, + )); + Arc::new(Schema::new(vec![ + Field::new("row_id", DataType::Int32, false), + Field::new("messages", kind.field_data_type(target_item), true), + ])) +} + +// Helper to extract message values from a nested list column. +// Returns the values at indices 0 and 1 from either a ListArray or LargeListArray. +fn extract_nested_list_values( + kind: NestedListKind, + column: &ArrayRef, +) -> (ArrayRef, ArrayRef) { + match kind { + NestedListKind::List => { + let list = column + .as_any() + .downcast_ref::<ListArray>() + .expect("messages should be a ListArray"); + (list.value(0), list.value(1)) + } + NestedListKind::LargeList => { + let list = column + .as_any() + .downcast_ref::<LargeListArray>() + .expect("messages should be a LargeListArray"); + (list.value(0), list.value(1)) + } + } +} + +// Helper to set up a nested list test fixture. +// Creates an in-memory store, writes the provided batches to parquet files, +// creates a SessionContext, and registers the resulting table. +// Returns the prepared context ready for queries. +async fn setup_nested_list_test( + kind: NestedListKind, + prefix_base: &str, + batches: Vec<(String, RecordBatch)>, + table_schema: SchemaRef, +) -> SessionContext { + let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>; + let prefix = format!("{}_{}", kind.name(), prefix_base); + + for (filename, batch) in batches { + write_parquet(batch, Arc::clone(&store), &format!("{prefix}/{filename}")).await; + } + + let ctx = test_context(); + register_memory_listing_table( + &ctx, + Arc::clone(&store), + &format!("memory:///{prefix}/"), + table_schema, + ) + .await; + + ctx +} + +async fn assert_nested_list_struct_schema_evolution(kind: NestedListKind) -> Result<()> { + // old.parquet shape: messages item struct has only (id, name), no `chain`. + let old_batch = nested_messages_batch( + kind, + 1, + &[ + NestedMessageRow { + id: 10, + name: "alpha", + chain: None, + ignored: None, + }, + NestedMessageRow { + id: 20, + name: "beta", + chain: None, + ignored: None, + }, + ], + &message_fields(DataType::Utf8, true, false, false), + ); + + // new.parquet shape: messages item struct adds nullable `chain` and extra `ignored`. + let new_batch = nested_messages_batch( + kind, + 2, + &[NestedMessageRow { + id: 30, + name: "gamma", + chain: Some("eth"), + ignored: Some(99), + }], + &message_fields(DataType::Utf8, true, true, true), + ); + + // Logical table schema expects evolved shape (id, name, nullable `chain`) and + // should ignore source-only `ignored` during reads. + let table_schema = + nested_list_table_schema(kind, target_message_fields(DataType::Utf8, true)); + + let ctx = setup_nested_list_test( + kind, + "struct_evolution", + vec![ + ("old.parquet".to_string(), old_batch), + ("new.parquet".to_string(), new_batch), + ], + table_schema, + ) + .await; + + let select_all = ctx + .sql("SELECT * FROM t ORDER BY row_id") + .await?
+ .collect() + .await?; + let all_rows = concat_batches(&select_all[0].schema(), &select_all)?; + + let row_ids = all_rows + .column(0) + .as_any() + .downcast_ref::<Int32Array>() + .expect("row_id should be Int32"); + assert_eq!(row_ids.values(), &[1, 2]); + + let (messages0, messages1) = extract_nested_list_values(kind, all_rows.column(1)); + + let messages0 = messages0 + .as_any() + .downcast_ref::<StructArray>() + .expect("messages[0] should be a StructArray"); + let old_ids = messages0 + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::<Int64Array>() + .unwrap(); + assert_eq!(old_ids.values(), &[10, 20]); + let old_chain = messages0 + .column_by_name("chain") + .unwrap() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + assert_eq!(old_chain.iter().collect::<Vec<_>>(), vec![None, None]); + + let messages1 = messages1 + .as_any() + .downcast_ref::<StructArray>() + .expect("messages[1] should be a StructArray"); + assert!( + messages1.column_by_name("ignored").is_none(), + "extra source fields should not appear in the logical schema" + ); + let new_chain = messages1 + .column_by_name("chain") + .unwrap() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + assert_eq!(new_chain.iter().collect::<Vec<_>>(), vec![Some("eth")]); + + let projected = ctx + .sql( + "SELECT row_id, get_field(messages[1], 'id') AS msg_id, \ + get_field(messages[1], 'chain') AS chain \ + FROM t ORDER BY row_id", + ) + .await? + .collect() + .await?; + + #[rustfmt::skip] + let expected = [ + "+--------+--------+-------+", + "| row_id | msg_id | chain |", + "+--------+--------+-------+", + "| 1 | 10 | |", + "| 2 | 30 | eth |", + "+--------+--------+-------+", + ]; + assert_batches_eq!(expected, &projected); + + Ok(()) +} + +// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with +// the default value for the field type: +// - Int64 columns are filled with `1` +// - Utf8 columns are filled with `'b'` +#[derive(Debug)] +struct CustomPhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result<Arc<dyn PhysicalExprAdapter>> { + Ok(Arc::new(CustomPhysicalExprAdapter { + logical_file_schema: Arc::clone(&logical_file_schema), + physical_file_schema: Arc::clone(&physical_file_schema), + inner: Arc::new(DefaultPhysicalExprAdapter::new( + logical_file_schema, + physical_file_schema, + )), + })) + } +} + +#[derive(Debug, Clone)] +struct CustomPhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + inner: Arc<dyn PhysicalExprAdapter>, +} + +impl PhysicalExprAdapter for CustomPhysicalExprAdapter { + fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> { + expr = expr + .transform(|expr| { + if let Some(column) = expr.downcast_ref::<Column>() { + let field_name = column.name(); + if self + .physical_file_schema + .field_with_name(field_name) + .ok() + .is_none() + { + let field = self + .logical_file_schema + .field_with_name(field_name) + .map_err(|_| { + DataFusionError::Plan(format!( + "Field '{field_name}' not found in logical file schema", + )) + })?; + // If the field does not exist, create a default value expression. + // The defaults are deliberately non-null so tests can tell this + // adapter's behavior apart from the default (NULL-filling) adapter. + let default_value = match field.data_type() { + DataType::Int64 => ScalarValue::Int64(Some(1)), + DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())), + _ => unimplemented!( + "Unsupported data type: {}", + field.data_type() + ), + }; + return Ok(Transformed::yes(Arc::new(
expressions::Literal::new(default_value), + ))); + } + } + + Ok(Transformed::no(expr)) + }) + .data()?; + self.inner.rewrite(expr) + } +} + +#[tokio::test] +async fn test_custom_schema_adapter_and_custom_expression_adapter() { + let batch = + record_batch!(("extra", Int64, [1, 2, 3]), ("c1", Int32, [1, 2, 3])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + let path = "test.parquet"; + write_parquet(batch, store.clone(), path).await; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + Field::new("c2", DataType::Utf8, true), + ])); + + let mut cfg = SessionConfig::new() + // Disable statistics collection for this test; otherwise early pruning makes it hard to demonstrate data adaptation + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + assert!( + !ctx.state() + .config_mut() + .options_mut() + .execution + .collect_statistics + ); + assert!(!ctx.state().config().collect_statistics()); + + // Test with DefaultPhysicalExprAdapterFactory - missing columns are filled with NULL + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 IS NULL") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Test with a custom physical expr adapter + // PhysicalExprAdapterFactory now handles both predicates AND projections + // CustomPhysicalExprAdapterFactory fills missing columns with 'b' for Utf8 + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.deregister_table("t").unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") + .await + .unwrap() + .collect() + .await + .unwrap(); + // With CustomPhysicalExprAdapterFactory, missing column c2 is filled with 'b' + // in both the predicate (c2 = 'b' becomes 'b' = 'b' -> true) and the projection + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| b | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); +} + +/// Test demonstrating how to implement a custom PhysicalExprAdapterFactory +/// that fills missing columns with non-null default values. +/// +/// PhysicalExprAdapterFactory rewrites expressions to use literals for +/// missing columns, handling schema evolution efficiently at planning time.
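+/// +/// Conceptually (a sketch, not the exact planner representation): for a file +/// that lacks `c2`, the adapter substitutes a literal for the missing column +/// reference, so the predicate `Column("c2") = Literal("b")` becomes +/// `Literal("b") = Literal("b")`, which is true for every row.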
+#[tokio::test] +async fn test_physical_expr_adapter_with_non_null_defaults() { + // File only has c1 column + let batch = record_batch!(("c1", Int32, [10, 20, 30])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + write_parquet(batch, store.clone(), "defaults_test.parquet").await; + + // Table schema has additional columns c2 (Utf8) and c3 (Int64) that don't exist in file + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), // type differs from file (Int32 vs Int64) + Field::new("c2", DataType::Utf8, true), // missing from file + Field::new("c3", DataType::Int64, true), // missing from file + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // CustomPhysicalExprAdapterFactory fills: + // - missing Utf8 columns with 'b' + // - missing Int64 columns with 1 + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + // Query all columns - missing columns should have default values + let batches = ctx + .sql("SELECT c1, c2, c3 FROM t ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // c1 is cast from Int32 to Int64, c2 defaults to 'b', c3 defaults to 1 + let expected = [ + "+----+----+----+", + "| c1 | c2 | c3 |", + "+----+----+----+", + "| 10 | b | 1 |", + "| 20 | b | 1 |", + "| 30 | b | 1 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify predicates work with default values + // c3 = 1 should match all rows since default is 1 + let batches = ctx + .sql("SELECT c1 FROM t WHERE c3 = 1 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 10 |", + "| 20 |", + "| 30 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); + + // c3 = 999 should match no rows + let batches = ctx + .sql("SELECT c1 FROM t WHERE c3 = 999") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_struct_schema_evolution_projection_and_filter() -> Result<()> { + use std::collections::HashMap; + + // Physical struct: {id: Int32, name: Utf8} + let physical_struct_fields: Fields = vec![ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, true)), + ] + .into(); + + let struct_array = StructArray::new( + physical_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef, + ], + None, + ); + + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct(physical_struct_fields), + true, + )])); + + let batch = + RecordBatch::try_new(Arc::clone(&physical_schema), vec![Arc::new(struct_array)])?; + + let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>; + let store_url =
ObjectStoreUrl::parse("memory://").unwrap(); + write_parquet(batch, store.clone(), "struct_evolution.parquet").await; + + // Logical struct: {id: Int64?, name: Utf8?, extra: Boolean?} + metadata + let logical_struct_fields: Fields = vec![ + Arc::new(Field::new("id", DataType::Int64, true)), + Arc::new(Field::new("name", DataType::Utf8, true)), + Arc::new(Field::new("extra", DataType::Boolean, true).with_metadata( + HashMap::from([("nested_meta".to_string(), "1".to_string())]), + )), + ] + .into(); + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("s", DataType::Struct(logical_struct_fields), false) + .with_metadata(HashMap::from([("top_meta".to_string(), "1".to_string())])), + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + let batches = ctx + .sql("SELECT s FROM t") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(batches.len(), 1); + + // Verify top-level metadata propagation + let output_schema = batches[0].schema(); + let s_field = output_schema.field_with_name("s").unwrap(); + assert_eq!( + s_field.metadata().get("top_meta").map(String::as_str), + Some("1") + ); + + // Verify nested struct type/field propagation + values + let s_array = batches[0] + .column(0) + .as_any() + .downcast_ref::<StructArray>() + .expect("expected struct array"); + + let id_array = s_array + .column_by_name("id") + .expect("id column") + .as_any() + .downcast_ref::<Int64Array>() + .expect("id should be cast to Int64"); + assert_eq!(id_array.values(), &[1, 2, 3]); + + let extra_array = s_array.column_by_name("extra").expect("extra column"); + assert_eq!(extra_array.null_count(), 3); + + // Verify nested field metadata propagation + let extra_field = match s_field.data_type() { + DataType::Struct(fields) => fields + .iter() + .find(|f| f.name() == "extra") + .expect("extra field"), + other => panic!("expected struct type for s, got {other:?}"), + }; + assert_eq!( + extra_field + .metadata() + .get("nested_meta") + .map(String::as_str), + Some("1") + ); + + // Smoke test: filtering on a missing nested field evaluates correctly + let filtered = ctx + .sql("SELECT get_field(s, 'extra') AS extra FROM t WHERE get_field(s, 'extra') IS NULL") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].num_rows(), 3); + let extra = filtered[0] + .column(0) + .as_any() + .downcast_ref::<BooleanArray>() + .expect("extra should be a boolean array"); + assert_eq!(extra.null_count(), 3); + + Ok(()) +} + +/// Macro to generate paired test functions for List and LargeList variants. +/// Expands to two `#[tokio::test]` functions with the specified names. +macro_rules! test_struct_schema_evolution_pair { + ( + list: $list_test:ident, + large_list: $large_list_test:ident, + fn: $assertion_fn:path $(, args: $($arg:expr),+)?
+ ) => { + #[tokio::test] + async fn $list_test() { + $assertion_fn(NestedListKind::List $(, $($arg),+)?).await; + } + + #[tokio::test] + async fn $large_list_test() { + $assertion_fn(NestedListKind::LargeList $(, $($arg),+)?).await; + } + }; + ( + list: $list_test:ident, + large_list: $large_list_test:ident, + fn_result: $assertion_fn:path + ) => { + #[tokio::test] + async fn $list_test() -> Result<()> { + $assertion_fn(NestedListKind::List).await + } + + #[tokio::test] + async fn $large_list_test() -> Result<()> { + $assertion_fn(NestedListKind::LargeList).await + } + }; +} + +test_struct_schema_evolution_pair!( + list: test_list_struct_schema_evolution_end_to_end, + large_list: test_large_list_struct_schema_evolution_end_to_end, + fn_result: assert_nested_list_struct_schema_evolution +); + +async fn assert_nested_list_struct_schema_evolution_errors( + kind: NestedListKind, + chain_type: DataType, + chain_nullable: bool, + expected_error: &str, +) { + let batch = nested_messages_batch( + kind, + 1, + &[NestedMessageRow { + id: 10, + name: "alpha", + chain: Some("eth"), + ignored: None, + }], + &message_fields(DataType::Utf8, true, true, false), + ); + + let table_schema = + nested_list_table_schema(kind, target_message_fields(chain_type, chain_nullable)); + + let ctx = setup_nested_list_test( + kind, + "struct_evolution_error", + vec![("data.parquet".to_string(), batch)], + table_schema, + ) + .await; + + let err = ctx + .sql("SELECT * FROM t") + .await + .unwrap() + .collect() + .await + .unwrap_err(); + assert!( + err.to_string().contains(expected_error), + "expected error to contain '{expected_error}', got: {err}" + ); +} + +async fn assert_non_nullable_missing_chain_field_fails(kind: NestedListKind) { + assert_nested_list_struct_schema_evolution_errors( + kind, + DataType::Utf8, + false, + "non-nullable", + ) + .await; +} + +async fn assert_incompatible_chain_field_fails(kind: NestedListKind) { + assert_nested_list_struct_schema_evolution_errors( + kind, + incompatible_chain_type(), + true, + "Cannot cast struct field 'chain'", + ) + .await; +} + +fn incompatible_chain_type() -> DataType { + DataType::Struct(vec![Arc::new(Field::new("value", DataType::Utf8, true))].into()) +} + +test_struct_schema_evolution_pair!( + list: test_list_struct_schema_evolution_non_nullable_missing_field_fails, + large_list: test_large_list_struct_schema_evolution_non_nullable_missing_field_fails, + fn: assert_non_nullable_missing_chain_field_fails +); + +test_struct_schema_evolution_pair!( + list: test_list_struct_schema_evolution_incompatible_field_fails, + large_list: test_large_list_struct_schema_evolution_incompatible_field_fails, + fn: assert_incompatible_chain_field_fails +); + +/// Test demonstrating that a single PhysicalExprAdapterFactory instance can be +/// reused across multiple ListingTable instances. +/// +/// This addresses the concern: "This is important for ListingTable. A test for +/// ListingTable would add assurance that the functionality is retained [i.e. 
we +/// can re-use a PhysicalExprAdapterFactory] +#[tokio::test] +async fn test_physical_expr_adapter_factory_reuse_across_tables() { + // Create two different parquet files with different schemas + // File 1: has column c1 only + let batch1 = record_batch!(("c1", Int32, [1, 2, 3])).unwrap(); + // File 2: has column c1 only but different data + let batch2 = record_batch!(("c1", Int32, [10, 20, 30])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + + // Write files to different paths + write_parquet(batch1, store.clone(), "table1/data.parquet").await; + write_parquet(batch2, store.clone(), "table2/data.parquet").await; + + // Table schema has additional columns that don't exist in files + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + Field::new("c2", DataType::Utf8, true), // missing from files + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // Create ONE factory instance wrapped in Arc - this will be REUSED + let factory: Arc<dyn PhysicalExprAdapterFactory> = + Arc::new(CustomPhysicalExprAdapterFactory); + + // Create ListingTable 1 using the shared factory + let listing_table_config1 = + ListingTableConfig::new(ListingTableUrl::parse("memory:///table1/").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::clone(&factory)); // Clone the Arc rather than creating a new factory + + let table1 = ListingTable::try_new(listing_table_config1).unwrap(); + ctx.register_table("t1", Arc::new(table1)).unwrap(); + + // Create ListingTable 2 using the SAME factory instance + let listing_table_config2 = + ListingTableConfig::new(ListingTableUrl::parse("memory:///table2/").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::clone(&factory)); // Reuse the same factory + + let table2 = ListingTable::try_new(listing_table_config2).unwrap(); + ctx.register_table("t2", Arc::new(table2)).unwrap(); + + // Verify table 1 works correctly with the shared factory + // CustomPhysicalExprAdapterFactory fills missing Utf8 columns with 'b' + let batches = ctx + .sql("SELECT c1, c2 FROM t1 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | b |", + "| 2 | b |", + "| 3 | b |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify table 2 also works correctly with the SAME shared factory + let batches = ctx + .sql("SELECT c1, c2 FROM t2 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 10 | b |", + "| 20 | b |", + "| 30 | b |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify predicates work on both tables with the shared factory + let batches = ctx + .sql("SELECT c1 FROM t1 WHERE c2 = 'b' ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 1 |", + "| 2 |", + "| 3 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = ctx + .sql("SELECT c1 FROM t2 WHERE c2 = 'b' ORDER BY c1") +
.await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 10 |", + "| 20 |", + "| 30 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); +} diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 5135f956852c3..9ff8137687c95 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -21,7 +21,7 @@ use std::path::Path; use std::sync::Arc; use crate::parquet::utils::MetricsFinder; -use crate::parquet::{create_data_batch, Scenario}; +use crate::parquet::{Scenario, create_data_batch}; use arrow::datatypes::SchemaRef; use arrow::util::pretty::pretty_format_batches; @@ -29,17 +29,17 @@ use datafusion::common::Result; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::prelude::SessionContext; -use datafusion_common::{assert_contains, DFSchema}; +use datafusion_common::{DFSchema, assert_contains}; use datafusion_datasource_parquet::{ParquetAccessPlan, RowGroupAccess}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{col, lit, Expr}; -use datafusion_physical_plan::metrics::{MetricValue, MetricsSet}; +use datafusion_expr::{Expr, col, lit}; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; -use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::file::properties::WriterProperties; use tempfile::NamedTempFile; @@ -257,7 +257,10 @@ async fn bad_selection() { .await .unwrap_err(); let err_string = err.to_string(); - assert_contains!(&err_string, "Internal error: Invalid ParquetAccessPlan Selection. Row group 0 has 5 rows but selection only specifies 4 rows"); + assert_contains!( + &err_string, + "Row group 0 has 5 rows but selection only specifies 4 rows." 
+ ); } /// Return a RowSelection of 1 rows from a row group of 5 rows @@ -355,11 +358,11 @@ impl TestFull { let source = if let Some(predicate) = predicate { let df_schema = DFSchema::try_from(schema.clone())?; let predicate = ctx.create_physical_expr(predicate, &df_schema)?; - Arc::new(ParquetSource::default().with_predicate(predicate)) + Arc::new(ParquetSource::new(schema.clone()).with_predicate(predicate)) } else { - Arc::new(ParquetSource::default()) + Arc::new(ParquetSource::new(schema.clone())) }; - let config = FileScanConfigBuilder::new(object_store_url, schema.clone(), source) + let config = FileScanConfigBuilder::new(object_store_url, source) .with_file(partitioned_file) .build(); @@ -406,7 +409,7 @@ fn get_test_data() -> TestData { .expect("tempfile creation"); let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) + .set_max_row_group_row_count(Some(row_per_group)) .build(); let batches = create_data_batch(scenario); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 64ee92eda2545..84396be8a6a67 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -18,31 +18,30 @@ use std::fs; use std::sync::Arc; +use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; -use datafusion_common::stats::Precision; use datafusion_common::DFSchema; +use datafusion_common::stats::Precision; +use datafusion_execution::cache::DefaultListFilesCache; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit::{ - DefaultFileStatisticsCache, DefaultListFilesCache, -}; +use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_expr::{col, lit, Expr}; +use datafusion_expr::{Expr, col, lit}; use datafusion::datasource::physical_plan::FileScanConfig; use datafusion_common::config::ConfigOptions; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::filter::FilterExec; use tempfile::tempdir; #[tokio::test] @@ -89,7 +88,7 @@ async fn check_stats_precision_with_filter_pushdown() { .unwrap(); assert!( - optimized_exec.as_any().is::<DataSourceExec>(), + optimized_exec.is::<DataSourceExec>(), "Sanity check that the pushdown did what we expected" ); // Scan with filter pushdown, stats are inexact @@ -127,8 +126,9 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec1.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Byte size is absent because we cannot estimate the output size + // of the Arrow data since there are variable length columns.
+ Precision::Absent, ); assert_eq!(get_static_cache_size(&state1), 1); @@ -142,8 +142,8 @@ ); assert_eq!( exec2.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Absent because the data contains variable length columns + Precision::Absent, ); assert_eq!(get_static_cache_size(&state2), 1); @@ -157,8 +157,8 @@ ); assert_eq!( exec3.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Absent because the data contains variable length columns + Precision::Absent, ); // List same file no increase assert_eq!(get_static_cache_size(&state1), 1); @@ -196,12 +196,9 @@ async fn list_files_with_session_level_cache() { //Session 1 first time list files assert_eq!(get_list_file_cache_size(&state1), 0); let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); - let data_source_exec = exec1.as_any().downcast_ref::<DataSourceExec>().unwrap(); + let data_source_exec = exec1.downcast_ref::<DataSourceExec>().unwrap(); let data_source = data_source_exec.data_source(); - let parquet1 = data_source - .as_any() - .downcast_ref::<ParquetSource>() - .unwrap(); + let parquet1 = data_source.downcast_ref::<ParquetSource>().unwrap(); assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet1.file_groups; @@ -212,12 +209,9 @@ //check session 1 cache result not show in session 2 assert_eq!(get_list_file_cache_size(&state2), 0); let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); - let data_source_exec = exec2.as_any().downcast_ref::<DataSourceExec>().unwrap(); + let data_source_exec = exec2.downcast_ref::<DataSourceExec>().unwrap(); let data_source = data_source_exec.data_source(); - let parquet2 = data_source - .as_any() - .downcast_ref::<ParquetSource>() - .unwrap(); + let parquet2 = data_source.downcast_ref::<ParquetSource>().unwrap(); assert_eq!(get_list_file_cache_size(&state2), 1); let fg2 = &parquet2.file_groups; @@ -228,12 +222,9 @@ //check session 1 cache result not show in session 2 assert_eq!(get_list_file_cache_size(&state1), 1); let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); - let data_source_exec = exec3.as_any().downcast_ref::<DataSourceExec>().unwrap(); + let data_source_exec = exec3.downcast_ref::<DataSourceExec>().unwrap(); let data_source = data_source_exec.data_source(); - let parquet3 = data_source - .as_any() - .downcast_ref::<ParquetSource>() - .unwrap(); + let parquet3 = data_source.downcast_ref::<ParquetSource>().unwrap(); assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet3.file_groups; diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 966f251613979..e6266b2c088d7 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -31,7 +31,7 @@ use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion::prelude::{ - col, lit, lit_timestamp_nano, Expr, ParquetReadOptions, SessionContext, + Expr, ParquetReadOptions, SessionContext, col, lit, lit_timestamp_nano, }; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; @@ -63,7 +63,7 @@ async fn single_file() { // Set the row group
size smaller so can test with fewer rows let props = WriterProperties::builder() - .set_max_row_group_size(1024) + .set_max_row_group_row_count(Some(1024)) .build(); // Only create the parquet file once as it is fairly large @@ -220,7 +220,6 @@ async fn single_file() { } #[tokio::test] -#[allow(dead_code)] async fn single_file_small_data_pages() { let batches = read_parquet_test_data( "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", @@ -231,7 +230,7 @@ // Set a low row count limit to improve page filtering let props = WriterProperties::builder() - .set_max_row_group_size(2048) + .set_max_row_group_row_count(Some(2048)) .set_data_page_row_count_limit(512) .set_write_batch_size(512) .build(); @@ -636,6 +635,43 @@ async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { config.options_mut().execution.parquet.pushdown_filters = true; let ctx = SessionContext::new_with_config(config); // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 7, // reads more rows than necessary from the cache, because another bitmap is applied afterwards + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_stats_issue_19561() -> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // force multiple batches to trigger the repeated metric compounding bug + config.options_mut().execution.batch_size = 1; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 4, + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_default_selections_only() +-> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // forcing filter selections minimizes the number of rows read from the cache + config + .options_mut() + .execution + .parquet + .force_filter_selections = true; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled PredicateCacheTest { expected_inner_records: 8, expected_records: 4, diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 097600e45eadd..e96bd49b9ace9 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,20 +19,21 @@ use crate::parquet::utils::MetricsFinder; use arrow::{ array::{ - make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, Decimal128Array, + DictionaryArray, FixedSizeBinaryArray, Float64Array, Int8Array, Int16Array, + Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, + make_array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch,
util::pretty::pretty_format_batches, }; +use arrow_schema::SchemaRef; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ - datasource::{provider_as_source, TableProvider}, + datasource::{TableProvider, provider_as_source}, physical_plan::metrics::MetricsSet, prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; @@ -43,16 +44,18 @@ use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::sync::Arc; use tempfile::NamedTempFile; +mod content_defined_chunking; mod custom_reader; #[cfg(feature = "parquet_encryption")] mod encryption; +mod expr_adapter; mod external_access_plan; mod file_statistics; mod filter_pushdown; +mod ordering; mod page_pruning; mod row_group_pruning; mod schema; -mod schema_adapter; mod schema_coercion; mod utils; @@ -109,6 +112,26 @@ struct ContextWithParquet { ctx: SessionContext, } +struct PruningMetric { + total_pruned: usize, + total_matched: usize, + total_fully_matched: usize, +} + +impl PruningMetric { + pub fn total_pruned(&self) -> usize { + self.total_pruned + } + + pub fn total_matched(&self) -> usize { + self.total_matched + } + + pub fn total_fully_matched(&self) -> usize { + self.total_fully_matched + } +} + /// The output of running one of the test cases struct TestOutput { /// The input query SQL @@ -126,8 +149,8 @@ struct TestOutput { impl TestOutput { /// retrieve the value of the named metric, if any fn metric_value(&self, metric_name: &str) -> Option<usize> { - if let Some((pruned, _matched)) = self.pruning_metric(metric_name) { - return Some(pruned); + if let Some(pm) = self.pruning_metric(metric_name) { + return Some(pm.total_pruned()); } self.parquet_metrics @@ -140,27 +163,33 @@ struct TestOutput { }) } - fn pruning_metric(&self, metric_name: &str) -> Option<(usize, usize)> { + fn pruning_metric(&self, metric_name: &str) -> Option<PruningMetric> { let mut total_pruned = 0; let mut total_matched = 0; + let mut total_fully_matched = 0; let mut found = false; for metric in self.parquet_metrics.iter() { let metric = metric.as_ref(); - if metric.value().name() == metric_name { - if let MetricValue::PruningMetrics { + if metric.value().name() == metric_name + && let MetricValue::PruningMetrics { pruning_metrics, ..
} = metric.value() - { - total_pruned += pruning_metrics.pruned(); - total_matched += pruning_metrics.matched(); - found = true; - } + { + total_pruned += pruning_metrics.pruned(); + total_matched += pruning_metrics.matched(); + total_fully_matched += pruning_metrics.fully_matched(); + + found = true; } } if found { - Some((total_pruned, total_matched)) + Some(PruningMetric { + total_pruned, + total_matched, + total_fully_matched, + }) } else { None } @@ -172,27 +201,33 @@ impl TestOutput { } /// The number of row_groups pruned / matched by bloom filter - fn row_groups_bloom_filter(&self) -> Option<(usize, usize)> { + fn row_groups_bloom_filter(&self) -> Option<PruningMetric> { self.pruning_metric("row_groups_pruned_bloom_filter") } /// The number of row_groups matched by statistics fn row_groups_matched_statistics(&self) -> Option<usize> { self.pruning_metric("row_groups_pruned_statistics") - .map(|(_pruned, matched)| matched) + .map(|pm| pm.total_matched()) + } + + /// The number of row_groups fully matched by statistics + fn row_groups_fully_matched_statistics(&self) -> Option<usize> { + self.pruning_metric("row_groups_pruned_statistics") + .map(|pm| pm.total_fully_matched()) } /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option<usize> { self.pruning_metric("row_groups_pruned_statistics") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) } /// Metric `files_ranges_pruned_statistics` tracks both pruned and matched count, /// for testing purpose, here it only aggregate the `pruned` count. fn files_ranges_pruned_statistics(&self) -> Option<usize> { self.pruning_metric("files_ranges_pruned_statistics") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) } /// The number of row_groups matched by bloom filter or statistics /// /// filter: 7 total -> 3 matched, this function returns 3 for the final matched /// count.
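+ /// For example (hypothetical counts): with 10 row groups, statistics might + /// match 7 of 10 (pruning 3) and the bloom filter then match 3 of those 7 + /// (pruning 4); this returns Some(3), while `row_groups_pruned` below sums + /// the two pruned counts to Some(7).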
fn row_groups_matched(&self) -> Option<usize> { - self.row_groups_bloom_filter() - .map(|(_pruned, matched)| matched) + self.row_groups_bloom_filter().map(|pm| pm.total_matched()) } /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option<usize> { self.row_groups_bloom_filter() - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) .zip(self.row_groups_pruned_statistics()) .map(|(a, b)| a + b) } @@ -216,7 +250,13 @@ impl TestOutput { /// The number of row pages pruned fn row_pages_pruned(&self) -> Option<usize> { self.pruning_metric("page_index_rows_pruned") - .map(|(pruned, _matched)| pruned) + .map(|pm| pm.total_pruned()) + } + + /// The number of row groups pruned by limit pruning + fn limit_pruned_row_groups(&self) -> Option<usize> { + self.pruning_metric("limit_pruned_row_groups") + .map(|pm| pm.total_pruned()) } fn description(&self) -> String { @@ -232,20 +272,41 @@ impl TestOutput { /// and the appropriate scenario impl ContextWithParquet { async fn new(scenario: Scenario, unit: Unit) -> Self { - Self::with_config(scenario, unit, SessionConfig::new()).await + Self::with_config(scenario, unit, SessionConfig::new(), None, None).await + } + + /// Set custom schema and batches for the test + pub async fn with_custom_data( + scenario: Scenario, + unit: Unit, + schema: Arc<Schema>, + batches: Vec<RecordBatch>, + ) -> Self { + Self::with_config( + scenario, + unit, + SessionConfig::new(), + Some(schema), + Some(batches), + ) + .await } async fn with_config( scenario: Scenario, unit: Unit, mut config: SessionConfig, + custom_schema: Option<SchemaRef>, + custom_batches: Option<Vec<RecordBatch>>, ) -> Self { // Use a single partition for deterministic results no matter how many CPUs the host has config = config.with_target_partitions(1); let file = match unit { Unit::RowGroup(row_per_group) => { config = config.with_parquet_bloom_filter_pruning(true); - make_test_file_rg(scenario, row_per_group).await + config.options_mut().execution.parquet.pushdown_filters = true; + make_test_file_rg(scenario, row_per_group, custom_schema, custom_batches) + .await } Unit::Page(row_per_page) => { config = config.with_parquet_page_index_pruning(true); @@ -516,9 +577,9 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec<u8> = (start..end).collect(); - let v16: Vec<u16> = (start as _..end as _).collect(); - let v32: Vec<u32> = (start as _..end as _).collect(); - let v64: Vec<u64> = (start as _..end as _).collect(); + let v16: Vec<u16> = (start as u16..end as u16).collect(); + let v32: Vec<u32> = (start as u32..end as u32).collect(); + let v64: Vec<u64> = (start as u64..end as u64).collect(); RecordBatch::try_new( schema, vec![ @@ -652,6 +713,7 @@ fn make_date_batch(offset: Duration) -> RecordBatch { /// of the column. It is *not* a table named service.name /// /// name | service.name +#[expect(clippy::needless_pass_by_value)] fn make_bytearray_batch( name: &str, string_values: Vec<&str>, @@ -707,6 +769,7 @@ /// of the column.
It is *not* a table named service.name /// /// name | service.name +#[expect(clippy::needless_pass_by_value)] fn make_names_batch(name: &str, service_name_values: Vec<&str>) -> RecordBatch { let num_rows = service_name_values.len(); let name: StringArray = std::iter::repeat_n(Some(name), num_rows).collect(); @@ -791,6 +854,7 @@ fn make_utf8_batch(value: Vec<Option<&str>>) -> RecordBatch { .unwrap() } +#[expect(clippy::needless_pass_by_value)] fn make_dictionary_batch(strings: Vec<&str>, integers: Vec<i32>) -> RecordBatch { let keys = Int32Array::from_iter(0..strings.len() as i32); let small_keys = Int16Array::from_iter(0..strings.len() as i16); @@ -839,6 +903,7 @@ fn make_dictionary_batch(strings: Vec<&str>, integers: Vec<i32>) -> RecordBatch .unwrap() } +#[expect(clippy::needless_pass_by_value)] fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> { match scenario { Scenario::Timestamps => { @@ -1071,7 +1136,12 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> { } /// Create a test parquet file with various data types -async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTempFile { +async fn make_test_file_rg( + scenario: Scenario, + row_per_group: usize, + custom_schema: Option<SchemaRef>, + custom_batches: Option<Vec<RecordBatch>>, +) -> NamedTempFile { let mut output_file = tempfile::Builder::new() .prefix("parquet_pruning") .suffix(".parquet") @@ -1079,13 +1149,19 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem .expect("tempfile creation"); let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) + .set_max_row_group_row_count(Some(row_per_group)) .set_bloom_filter_enabled(true) .set_statistics_enabled(EnabledStatistics::Page) .build(); - let batches = create_data_batch(scenario); - let schema = batches[0].schema(); + let (batches, schema) = + if let (Some(schema), Some(batches)) = (custom_schema, custom_batches) { + (batches, schema) + } else { + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); + (batches, schema) + }; let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/core/tests/parquet/ordering.rs b/datafusion/core/tests/parquet/ordering.rs new file mode 100644 index 0000000000000..faecb4ca6a861 --- /dev/null +++ b/datafusion/core/tests/parquet/ordering.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for ordering in Parquet sorting_columns metadata + +use datafusion::prelude::SessionContext; +use datafusion_common::Result; +use tempfile::tempdir; + +/// Test that CREATE TABLE ...
WITH ORDER writes sorting_columns to Parquet metadata +#[tokio::test] +async fn test_create_table_with_order_writes_sorting_columns() -> Result<()> { + use parquet::file::reader::FileReader; + use parquet::file::serialized_reader::SerializedFileReader; + use std::fs::File; + + let ctx = SessionContext::new(); + let tmp_dir = tempdir()?; + let table_path = tmp_dir.path().join("sorted_table"); + std::fs::create_dir_all(&table_path)?; + + // Create external table with ordering + let create_table_sql = format!( + "CREATE EXTERNAL TABLE sorted_data (a INT, b VARCHAR) \ + STORED AS PARQUET \ + LOCATION '{}' \ + WITH ORDER (a ASC NULLS FIRST, b DESC NULLS LAST)", + table_path.display() + ); + ctx.sql(&create_table_sql).await?; + + // Insert sorted data + ctx.sql("INSERT INTO sorted_data VALUES (1, 'x'), (2, 'y'), (3, 'z')") + .await? + .collect() + .await?; + + // Find the parquet file that was written + let parquet_files: Vec<_> = std::fs::read_dir(&table_path)? + .filter_map(|e| e.ok()) + .filter(|e| e.path().extension().is_some_and(|ext| ext == "parquet")) + .collect(); + + assert!( + !parquet_files.is_empty(), + "Expected at least one parquet file in {}", + table_path.display() + ); + + // Read the parquet file and verify sorting_columns metadata + let file = File::open(parquet_files[0].path())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + // Check that row group has sorting_columns + let row_group = metadata.row_group(0); + let sorting_columns = row_group.sorting_columns(); + + assert!( + sorting_columns.is_some(), + "Expected sorting_columns in row group metadata" + ); + let sorting = sorting_columns.unwrap(); + assert_eq!(sorting.len(), 2, "Expected 2 sorting columns"); + + // First column: a ASC NULLS FIRST (column_idx = 0) + assert_eq!(sorting[0].column_idx, 0, "First sort column should be 'a'"); + assert!( + !sorting[0].descending, + "First column should be ASC (descending=false)" + ); + assert!( + sorting[0].nulls_first, + "First column should have NULLS FIRST" + ); + + // Second column: b DESC NULLS LAST (column_idx = 1) + assert_eq!(sorting[1].column_idx, 1, "Second sort column should be 'b'"); + assert!( + sorting[1].descending, + "Second column should be DESC (descending=true)" + ); + assert!( + !sorting[1].nulls_first, + "Second column should have NULLS LAST" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 27bee10234b57..a41803191ad05 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -20,26 +20,29 @@ use std::sync::Arc; use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; -use arrow::array::RecordBatch; -use datafusion::datasource::file_format::parquet::ParquetFormat; +use arrow::array::{Int32Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion::execution::context::SessionState; -use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; +use datafusion::physical_plan::metrics::MetricValue; +use 
datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{col, lit, Expr}; +use datafusion_expr::{Expr, col, lit}; use datafusion_physical_expr::create_physical_expr; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use futures::StreamExt; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::path::Path; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; async fn get_parquet_exec( state: &SessionState, @@ -67,26 +70,19 @@ .await .unwrap(); - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(meta); let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(schema.clone()) .with_predicate(predicate) .with_enable_page_index(true) .with_pushdown_filters(pushdown_filters), ); - let base_config = FileScanConfigBuilder::new(object_store_url, schema, source) + let base_config = FileScanConfigBuilder::new(object_store_url, source) .with_file(partitioned_file) .build(); @@ -370,281 +366,367 @@ async fn prune_date64() { } macro_rules! int_tests { - ($bits:expr) => { - paste::item! { - #[tokio::test] - // null count min max - // page-0 0 -5 -1 - // page-1 0 -4 0 - // page-2 0 0 4 - // page-3 0 5 9 - async fn [<prune_int $bits _lt>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} < 1", $bits), - Some(0), - Some(5), - 11, - 5, - ) - .await; - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where -i{} > -1", $bits), - Some(0), - Some(5), - 11, - 5, - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _gt>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} > 8", $bits), - Some(0), - Some(15), - 1, - 5, - ) - .await; - - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where -i{} < -8", $bits), - Some(0), - Some(15), - 1, - 5, - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _eq>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} = 1", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - #[tokio::test] - async fn [<prune_int $bits _scalar_fun_and_eq>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where abs(i{}) = 1 and i{} = 1", $bits, $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _scalar_fun>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where abs(i{}) = 1", $bits), - Some(0), - Some(0), - 3, - 5 - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _complex_expr>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{}+1 = 1", $bits), - Some(0), - Some(0), - 2, - 5 - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _complex_expr_subtract>]() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where 1-i{} > 1", $bits), - Some(0), - Some(0), - 9, - 5 - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _eq_in_list>]() { - // result of sql "SELECT * FROM t where in (1)" - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} in (1)", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn [<prune_int $bits _eq_in_list_negated>]() { - // result of
sql "SELECT * FROM t where not in (1)" prune nothing - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} not in (1)", $bits), - Some(0), - Some(0), - 19, - 5 - ) - .await; - } + ($bits:expr, $fn_lt:ident, $fn_gt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_complex_expr_subtract:ident, $fn_eq_in_list:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + // null count min max + // page-0 0 -5 -1 + // page-1 0 -4 0 + // page-2 0 0 4 + // page-3 0 5 9 + async fn $fn_lt() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} < 1", $bits), + Some(0), + Some(5), + 11, + 5, + ) + .await; + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} > -1", $bits), + Some(0), + Some(5), + 11, + 5, + ) + .await; } - } + + #[tokio::test] + async fn $fn_gt() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} > 8", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} < -8", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} = 1", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + test_prune( + Scenario::Int, + &format!( + "SELECT * FROM t where abs(i{}) = 1 and i{} = 1", + $bits, $bits + ), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where abs(i{}) = 1", $bits), + Some(0), + Some(0), + 3, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{}+1 = 1", $bits), + Some(0), + Some(0), + 2, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_complex_expr_subtract() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where 1-i{} > 1", $bits), + Some(0), + Some(0), + 9, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} in (1)", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} not in (1)", $bits), + Some(0), + Some(0), + 19, + 5, + ) + .await; + } + }; } -int_tests!(8); -int_tests!(16); -int_tests!(32); -int_tests!(64); +int_tests!( + 8, + prune_int8_lt, + prune_int8_gt, + prune_int8_eq, + prune_int8_scalar_fun_and_eq, + prune_int8_scalar_fun, + prune_int8_complex_expr, + prune_int8_complex_expr_subtract, + prune_int8_eq_in_list, + prune_int8_eq_in_list_negated +); +int_tests!( + 16, + prune_int16_lt, + prune_int16_gt, + prune_int16_eq, + prune_int16_scalar_fun_and_eq, + prune_int16_scalar_fun, + prune_int16_complex_expr, + prune_int16_complex_expr_subtract, + prune_int16_eq_in_list, + prune_int16_eq_in_list_negated +); +int_tests!( + 32, + prune_int32_lt, + prune_int32_gt, + prune_int32_eq, + prune_int32_scalar_fun_and_eq, + prune_int32_scalar_fun, + prune_int32_complex_expr, + prune_int32_complex_expr_subtract, + prune_int32_eq_in_list, + prune_int32_eq_in_list_negated +); +int_tests!( + 64, 
+ prune_int64_lt, + prune_int64_gt, + prune_int64_eq, + prune_int64_scalar_fun_and_eq, + prune_int64_scalar_fun, + prune_int64_complex_expr, + prune_int64_complex_expr_subtract, + prune_int64_eq_in_list, + prune_int64_eq_in_list_negated +); macro_rules! uint_tests { - ($bits:expr) => { - paste::item! { - #[tokio::test] - // null count min max - // page-0 0 0 4 - // page-1 0 1 5 - // page-2 0 5 9 - // page-3 0 250 254 - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} < 6", $bits), - Some(0), - Some(5), - 11, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} > 253", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} = 6", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", $bits, $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where power(u{}, 2) = 25", $bits), - Some(0), - Some(0), - 2, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{}+1 = 6", $bits), - Some(0), - Some(0), - 2, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} in (6)", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (6)" prune nothing - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} not in (6)", $bits), - Some(0), - Some(0), - 19, - 5 - ) - .await; - } + ($bits:expr, $fn_lt:ident, $fn_gt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_eq_in_list:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + // null count min max + // page-0 0 0 4 + // page-1 0 1 5 + // page-2 0 5 9 + // page-3 0 250 254 + async fn $fn_lt() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} < 6", $bits), + Some(0), + Some(5), + 11, + 5, + ) + .await; } - } + + #[tokio::test] + async fn $fn_gt() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} > 253", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} = 6", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + test_prune( + Scenario::UInt, + &format!( + "SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", + $bits, $bits + ), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where power(u{}, 2) = 25", $bits), + Some(0), + Some(0), + 2, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{}+1 = 6", $bits), + Some(0), + Some(0), + 2, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} in (6)", 
$bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (6)" prune nothing + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} not in (6)", $bits), + Some(0), + Some(0), + 19, + 5, + ) + .await; + } + }; } -uint_tests!(8); -uint_tests!(16); -uint_tests!(32); -uint_tests!(64); +uint_tests!( + 8, + prune_uint8_lt, + prune_uint8_gt, + prune_uint8_eq, + prune_uint8_scalar_fun_and_eq, + prune_uint8_scalar_fun, + prune_uint8_complex_expr, + prune_uint8_eq_in_list, + prune_uint8_eq_in_list_negated +); +uint_tests!( + 16, + prune_uint16_lt, + prune_uint16_gt, + prune_uint16_eq, + prune_uint16_scalar_fun_and_eq, + prune_uint16_scalar_fun, + prune_uint16_complex_expr, + prune_uint16_eq_in_list, + prune_uint16_eq_in_list_negated +); +uint_tests!( + 32, + prune_uint32_lt, + prune_uint32_gt, + prune_uint32_eq, + prune_uint32_scalar_fun_and_eq, + prune_uint32_scalar_fun, + prune_uint32_complex_expr, + prune_uint32_eq_in_list, + prune_uint32_eq_in_list_negated +); +uint_tests!( + 64, + prune_uint64_lt, + prune_uint64_gt, + prune_uint64_eq, + prune_uint64_scalar_fun_and_eq, + prune_uint64_scalar_fun, + prune_uint64_complex_expr, + prune_uint64_eq_in_list, + prune_uint64_eq_in_list_negated +); #[tokio::test] // null count min max @@ -968,3 +1050,56 @@ fn cast_count_metric(metric: MetricValue) -> Option { _ => None, } } + +#[tokio::test] +async fn test_parquet_opener_without_page_index() { + // Defines a simple schema and batch + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + // Create a temp file + let file = tempfile::Builder::new() + .suffix(".parquet") + .tempfile() + .unwrap(); + let path = file.path().to_str().unwrap().to_string(); + + // Write parquet WITHOUT page index + // The default WriterProperties does not write page index, but we set it explicitly + // to be robust against future changes in defaults as requested by reviewers. + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::None) + .build(); + + let file_fs = std::fs::File::create(&path).unwrap(); + let mut writer = ArrowWriter::try_new(file_fs, batch.schema(), Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + // Setup SessionContext with PageIndex enabled + // This triggers the ParquetOpener to try and load page index if available + let config = SessionConfig::new().with_parquet_page_index_pruning(true); + + let ctx = SessionContext::new_with_config(config); + + // Register the table + ctx.register_parquet("t", &path, Default::default()) + .await + .unwrap(); + + // Query the table + // If the bug exists, this might fail because Opener tries to load PageIndex forcefully + let df = ctx.sql("SELECT * FROM t").await.unwrap(); + let batches = df + .collect() + .await + .expect("Failed to read parquet file without page index"); + + // We expect this to succeed, but currently it might fail + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); +} diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 0411298055f26..3ec3541af977a 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -18,8 +18,12 @@ //! 
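One detail worth pinning down from test_parquet_opener_without_page_index above: whether a page index exists in a file is a function of the writer's statistics level. A sketch of the three levels in parquet::file::properties (the test uses None to guarantee no page index is written):

use parquet::file::properties::{EnabledStatistics, WriterProperties};

// EnabledStatistics::None  -> no statistics at all (what the test writes)
// EnabledStatistics::Chunk -> row-group (column chunk) statistics only
// EnabledStatistics::Page  -> chunk statistics plus the page index
fn props_without_page_index() -> WriterProperties {
    WriterProperties::builder()
        .set_statistics_enabled(EnabledStatistics::None)
        .build()
}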
This file contains an end to end test of parquet pruning. It writes //! data into a parquet file and then verifies row groups are pruned as //! expected. +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::SessionConfig; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; use itertools::Itertools; use crate::parquet::Unit::RowGroup; @@ -30,10 +34,12 @@ struct RowGroupPruningTest { query: String, expected_errors: Option, expected_row_group_matched_by_statistics: Option, + expected_row_group_fully_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, expected_files_pruned_by_statistics: Option, expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, + expected_limit_pruned_row_groups: Option, expected_rows: usize, } impl RowGroupPruningTest { @@ -45,9 +51,11 @@ impl RowGroupPruningTest { expected_errors: None, expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_fully_matched_by_statistics: None, expected_files_pruned_by_statistics: None, expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, + expected_limit_pruned_row_groups: None, expected_rows: 0, } } @@ -76,6 +84,15 @@ impl RowGroupPruningTest { self } + // Set the expected fully matched row groups by statistics + fn with_fully_matched_by_stats( + mut self, + fully_matched_by_stats: Option, + ) -> Self { + self.expected_row_group_fully_matched_by_statistics = fully_matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; @@ -99,6 +116,11 @@ impl RowGroupPruningTest { self } + fn with_limit_pruned_row_groups(mut self, pruned_by_limit: Option) -> Self { + self.expected_limit_pruned_row_groups = pruned_by_limit; + self + } + /// Set the number of expected rows from the output of this test fn with_expected_rows(mut self, rows: usize) -> Self { self.expected_rows = rows; @@ -135,15 +157,74 @@ impl RowGroupPruningTest { ); let bloom_filter_metrics = output.row_groups_bloom_filter(); assert_eq!( - bloom_filter_metrics.map(|(_pruned, matched)| matched), + bloom_filter_metrics.as_ref().map(|pm| pm.total_matched()), self.expected_row_group_matched_by_bloom_filter, "mismatched row_groups_matched_bloom_filter", ); assert_eq!( - bloom_filter_metrics.map(|(pruned, _matched)| pruned), + bloom_filter_metrics.map(|pm| pm.total_pruned()), self.expected_row_group_pruned_by_bloom_filter, "mismatched row_groups_pruned_bloom_filter", ); + + assert_eq!( + output.result_rows, + self.expected_rows, + "Expected {} rows, got {}: {}", + output.result_rows, + self.expected_rows, + output.description(), + ); + } + + // Execute the test with the current configuration + async fn test_row_group_prune_with_custom_data( + self, + schema: Arc, + batches: Vec, + max_row_per_group: usize, + ) { + let output = ContextWithParquet::with_custom_data( + self.scenario, + RowGroup(max_row_per_group), + schema, + batches, + ) + .await + .query(&self.query) + .await; + + println!("{}", output.description()); + assert_eq!( + output.predicate_evaluation_errors(), + self.expected_errors, + "mismatched predicate_evaluation error" + ); + assert_eq!( + output.row_groups_matched_statistics(), + 
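Terminology for the two new counters above: "matched by statistics" means a row group may contain qualifying rows, while "fully matched" means the min/max statistics prove every row qualifies, so its row count can service a LIMIT without scanning. The distinction for an equality predicate, as an illustrative sketch (not the engine's pruning code):

/// Illustrative only: three-way outcome of min/max stats for `col = v`.
enum StatsMatch {
    Pruned,  // v outside [min, max]: no row can match
    Partial, // v inside the range: must scan to know
    Full,    // min == max == v: every row matches
}

fn classify_eq(min: i32, max: i32, v: i32) -> StatsMatch {
    if v < min || v > max {
        StatsMatch::Pruned
    } else if min == v && max == v {
        StatsMatch::Full
    } else {
        StatsMatch::Partial
    }
}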
self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); + assert_eq!( + output.row_groups_fully_matched_statistics(), + self.expected_row_group_fully_matched_by_statistics, + "mismatched row_groups_fully_matched_statistics", + ); + assert_eq!( + output.row_groups_pruned_statistics(), + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); + assert_eq!( + output.limit_pruned_row_groups(), + self.expected_limit_pruned_row_groups, + "mismatched limit_pruned_row_groups", + ); assert_eq!( output.result_rows, self.expected_rows, @@ -289,11 +370,16 @@ async fn prune_disabled() { let expected_rows = 10; let config = SessionConfig::new().with_parquet_pruning(false); - let output = - ContextWithParquet::with_config(Scenario::Timestamps, RowGroup(5), config) - .await - .query(query) - .await; + let output = ContextWithParquet::with_config( + Scenario::Timestamps, + RowGroup(5), + config, + None, + None, + ) + .await + .query(query) + .await; println!("{}", output.description()); // This should not prune any @@ -313,321 +399,365 @@ async fn prune_disabled() { // https://github.com/apache/datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on Int8 and Int16 columns are still buggy. macro_rules! int_tests { - ($bits:expr) => { - paste::item! { - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} < 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(3)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where -i{} > -1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(3)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where abs(i{}) = 1 and i{} = 1", $bits, $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where abs(i{}) = 1", $bits)) - .with_expected_errors(Some(0)) - 
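Also visible in the assertion hunk above: the bloom-filter metrics accessor changed shape, from a (pruned, matched) tuple to a value consumed via total_matched() / total_pruned(), which makes call sites self-describing and leaves room for more counters. The tuple-to-struct move in miniature (hypothetical PruneMetrics; the real type is whatever row_groups_bloom_filter() now returns):

// Hypothetical miniature of the tuple-to-struct migration above.
struct PruneMetrics {
    matched: usize,
    pruned: usize,
}

impl PruneMetrics {
    fn total_matched(&self) -> usize { self.matched }
    fn total_pruned(&self) -> usize { self.pruned }
}

fn report(pm: &PruneMetrics) -> String {
    // before: let (pruned, matched) = metrics; -- positional, easy to swap
    // after: named accessors document themselves at each call site
    format!("matched={}, pruned={}", pm.total_matched(), pm.total_pruned())
}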
.with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(3) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{}+1 = 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where 1-i{} > 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(9) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} in (1)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1000)", prune all - // test whether statistics works - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(1)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(0) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} not in (1)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(19) - .test_row_group_prune() - .await; - } + ($bits:expr, $fn_lt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_complex_expr_subtract:ident, $fn_eq_in_list:ident, $fn_eq_in_list_2:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + async fn $fn_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} < 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where -i{} > -1", 
$bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!( + "SELECT * FROM t where abs(i{}) = 1 and i{} = 1", + $bits, $bits + )) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where abs(i{}) = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{}+1 = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_complex_expr_subtract() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where 1-i{} > 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_2() { + // result of sql "SELECT * FROM t where in (1000)", prune all + // test whether statistics works + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + 
.with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} not in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; } }; } // int8/int16 are incorrect: https://github.com/apache/datafusion/issues/9779 -int_tests!(32); -int_tests!(64); +int_tests!( + 32, + prune_int32_lt, + prune_int32_eq, + prune_int32_scalar_fun_and_eq, + prune_int32_scalar_fun, + prune_int32_complex_expr, + prune_int32_complex_expr_subtract, + prune_int32_eq_in_list, + prune_int32_eq_in_list_2, + prune_int32_eq_in_list_negated +); +int_tests!( + 64, + prune_int64_lt, + prune_int64_eq, + prune_int64_scalar_fun_and_eq, + prune_int64_scalar_fun, + prune_int64_complex_expr, + prune_int64_complex_expr_subtract, + prune_int64_eq_in_list, + prune_int64_eq_in_list_2, + prune_int64_eq_in_list_negated +); // $bits: number of bits of the integer to test (8, 16, 32, 64) // $correct_bloom_filters: if false, replicates the // https://github.com/apache/datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on UInt8 and UInt16 columns are still buggy. macro_rules! uint_tests { - ($bits:expr) => { - paste::item! { - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} < 6", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(3)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} = 6", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", $bits, $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where power(u{}, 2) = 25", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM 
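A worked note on the in (100) expectations above: the Int scenario's values span roughly -5 to 9 (see the page comments earlier in this diff), so 100 falls outside the file-level range and the file is pruned before row-group pruning ever runs; hence pruned_files = Some(1) with zero row groups matched or pruned. The range check in miniature (range values assumed from the scenario comments):

// Miniature of the range check implied by the expectations above.
fn in_list_may_match(min: i32, max: i32, list: &[i32]) -> bool {
    list.iter().any(|v| (min..=max).contains(v))
}

fn check() {
    // Int scenario spans about [-5, 9]; 100 cannot match, so the file
    // is pruned outright (pruned_files = 1, no row-group work needed).
    assert!(!in_list_may_match(-5, 9, &[100]));
    // 1 falls inside the range: row groups must be consulted.
    assert!(in_list_may_match(-5, 9, &[1]));
}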
t where u{}+1 = 6", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} in (6)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1000)", prune all - // test whether statistics works - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} in (100)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(4)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(0) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} not in (6)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_pruned_files(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(19) - .test_row_group_prune() - .await; - } + ($bits:expr, $fn_lt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_eq_in_list:ident, $fn_eq_in_list_2:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + async fn $fn_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} < 6", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} = 6", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!( + "SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", + $bits, $bits + )) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + 
.with_query(&format!("SELECT * FROM t where power(u{}, 2) = 25", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{}+1 = 6", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} in (6)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_2() { + // result of sql "SELECT * FROM t where in (1000)", prune all + // test whether statistics works + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} in (100)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(4)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} not in (6)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; } }; } // uint8/uint16 are incorrect: https://github.com/apache/datafusion/issues/9779 -uint_tests!(32); -uint_tests!(64); +uint_tests!( + 32, + prune_uint32_lt, + prune_uint32_eq, + prune_uint32_scalar_fun_and_eq, + prune_uint32_scalar_fun, + prune_uint32_complex_expr, + prune_uint32_eq_in_list, + prune_uint32_eq_in_list_2, + prune_uint32_eq_in_list_negated +); +uint_tests!( + 64, + prune_uint64_lt, + prune_uint64_eq, + prune_uint64_scalar_fun_and_eq, + prune_uint64_scalar_fun, + prune_uint64_complex_expr, + prune_uint64_eq_in_list, + prune_uint64_eq_in_list_2, + prune_uint64_eq_in_list_negated +); #[tokio::test] async fn prune_int32_eq_large_in_list() { @@ -1636,3 +1766,240 @@ async fn test_bloom_filter_decimal_dict() { .test_row_group_prune() .await; } + +// Helper function to create a batch with a single Int32 column. 
+fn make_i32_batch( + name: &str, + values: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)])); + let array: ArrayRef = Arc::new(Int32Array::from(values)); + RecordBatch::try_new(schema, vec![array]).map_err(DataFusionError::from) +} + +// Helper function to create a batch with two Int32 columns +fn make_two_col_i32_batch( + name_a: &str, + name_b: &str, + values_a: Vec, + values_b: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![ + Field::new(name_a, DataType::Int32, false), + Field::new(name_b, DataType::Int32, false), + ])); + let array_a: ArrayRef = Arc::new(Int32Array::from(values_a)); + let array_b: ArrayRef = Arc::new(Int32Array::from(values_b)); + RecordBatch::try_new(schema, vec![array_a, array_b]).map_err(DataFusionError::from) +} + +#[tokio::test] +async fn test_limit_pruning_basic() -> datafusion_common::error::Result<()> { + // Scenario: Simple integer column, multiple row groups + // Query: SELECT c1 FROM t WHERE c1 = 0 LIMIT 2 + // We expect 2 rows in total. + + // Row Group 0: c1 = [0, -2] -> Partially matched, 1 row + // Row Group 1: c1 = [1, 2] -> Fully matched, 2 rows + // Row Group 2: c1 = [3, 4] -> Fully matched, 2 rows + // Row Group 3: c1 = [5, 6] -> Fully matched, 2 rows + // Row Group 4: c1 = [-1, -2] -> Not matched + + // If limit = 2, and RG1 is fully matched and has 2 rows, we should + // only scan RG1 and prune other row groups + // RG4 is pruned by statistics. RG2 and RG3 are pruned by limit. + // So 2 row groups are effectively pruned due to limit pruning. + + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let query = "SELECT c1 FROM t WHERE c1 >= 0 LIMIT 2"; + + let batches = vec![ + make_i32_batch("c1", vec![0, -2])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![-1, -2])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) // Assuming Scenario::Int can handle this data + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(2) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_limit_pruned_row_groups(Some(3)) + .test_row_group_prune_with_custom_data(schema, batches, 2) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_complex_filter() -> datafusion_common::error::Result<()> { + // Test Case 1: Complex filter with two columns (a = 1 AND b > 1 AND b < 4) + // Row Group 0: a=[1,1,1], b=[0,2,3] -> Partially matched, 2 rows match (b=2,3) + // Row Group 1: a=[1,1,1], b=[2,2,2] -> Fully matched, 3 rows + // Row Group 2: a=[1,1,1], b=[2,3,3] -> Fully matched, 3 rows + // Row Group 3: a=[1,1,1], b=[2,2,3] -> Fully matched, 3 rows + // Row Group 4: a=[2,2,2], b=[2,2,2] -> Not matched (a != 1) + // Row Group 5: a=[1,1,1], b=[5,6,7] -> Not matched (b >= 4) + + // With LIMIT 5, we need RG1 (3 rows) + RG2 (2 rows from 3) = 5 rows + // RG4 and RG5 should be pruned by statistics + // RG3 should be pruned by limit + // RG0 is partially matched, so it depends on the order + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let query = "SELECT a, b FROM t WHERE a = 1 AND b > 1 AND b < 4 LIMIT 5"; + + let batches = vec![ + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![0, 2, 3])?, + 
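The limit-pruning arithmetic in test_limit_pruning_basic is worth spelling out, since the inline comment undercounts: RG1 alone covers LIMIT 2, so the other fully matched groups (RG2, RG3) and the partially matched RG0 are all skippable, which is the Some(3) asserted for limit_pruned_row_groups. A small model that reproduces the expectations of all the limit tests in this file (a sketch of the bookkeeping, not the engine's implementation):

// Sketch: guaranteed rows come only from fully matched row groups; once
// those cover the limit, every remaining candidate group (fully or
// partially matched) can be skipped.
fn limit_pruned_groups(
    fully_matched_rows: &[usize], // row counts of fully matched groups
    partial_candidates: usize,    // partially matched groups
    limit: usize,
) -> usize {
    let mut need = limit;
    let mut kept = 0;
    for rows in fully_matched_rows {
        if need == 0 {
            break;
        }
        kept += 1;
        need = need.saturating_sub(*rows);
    }
    if need == 0 {
        (fully_matched_rows.len() - kept) + partial_candidates
    } else {
        0 // limit not provable from stats: scan all candidates
    }
}

fn check_against_tests() {
    assert_eq!(limit_pruned_groups(&[2, 2, 2], 1, 2), 3); // basic
    assert_eq!(limit_pruned_groups(&[3, 3, 3], 1, 5), 2); // complex filter
    assert_eq!(limit_pruned_groups(&[4, 4, 4, 4], 0, 8), 2); // multiple fully matched
    assert_eq!(limit_pruned_groups(&[], 4, 3), 0); // no fully matched
    assert_eq!(limit_pruned_groups(&[4, 4], 2, 10), 0); // exceeds fully matched
}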
make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 3, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![2, 2, 2], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![5, 6, 7])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(5) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 are matched + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(2)) // RG4,5 are pruned + .with_limit_pruned_row_groups(Some(2)) // RG0, RG3 is pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_multiple_fully_matched() +-> datafusion_common::error::Result<()> { + // Test Case 2: Limit requires multiple fully matched row groups + // Row Group 0: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 1: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 2: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 3: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 4: a=[1,2,3,4] -> Not matched + + // With LIMIT 8, we need RG0 (4 rows) + RG1 (4 rows) 8 rows + // RG2,3 should be pruned by limit + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 5 LIMIT 8"; + + let batches = vec![ + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![1, 2, 3, 4])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(8) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(2)) // RG2,3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_no_fully_matched() -> datafusion_common::error::Result<()> { + // Test Case 3: No fully matched row groups - all are partially matched + // Row Group 0: a=[1,2,3] -> Partially matched, 1 row (a=2) + // Row Group 1: a=[2,3,4] -> Partially matched, 1 row (a=2) + // Row Group 2: a=[2,5,6] -> Partially matched, 1 row (a=2) + // Row Group 3: a=[2,7,8] -> Partially matched, 1 row (a=2) + // Row Group 4: a=[9,10,11] -> Not matched + + // With LIMIT 3, we need to scan RG0,1,2 to get 3 matching rows + // Cannot prune much by limit since all matching RGs are partial + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 2 LIMIT 3"; + + let batches = vec![ + make_i32_batch("a", vec![1, 2, 3])?, + make_i32_batch("a", vec![2, 3, 4])?, + make_i32_batch("a", vec![2, 5, 6])?, + make_i32_batch("a", vec![2, 7, 8])?, + make_i32_batch("a", vec![9, 10, 11])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(3) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(0)) + 
.with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // RG3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_exceeds_fully_matched() -> datafusion_common::error::Result<()> +{ + // Test Case 4: Limit exceeds all fully matched rows, need partially matched + // Row Group 0: a=[10,11,12,12] -> Partially matched, 1 row (a=10) + // Row Group 1: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 2: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 3: a=[10,13,14,11] -> Partially matched, 1 row (a=10) + // Row Group 4: a=[20,21,22,22] -> Not matched + + // With LIMIT 10, we need RG1 (4) + RG2 (4) = 8 from fully matched + // Still need 2 more, so we need to scan partially matched RG0 and RG3 + // All matching row groups should be scanned, only RG4 pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 10 LIMIT 10"; + + let batches = vec![ + make_i32_batch("a", vec![10, 11, 12, 12])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 13, 14, 11])?, + make_i32_batch("a", vec![20, 21, 22, 22])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(10) // Total: 1 + 4 + 4 + 1 = 10 + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // No limit pruning since we need all RGs + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + Ok(()) +} diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs deleted file mode 100644 index 40fc6176e212b..0000000000000 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ /dev/null @@ -1,553 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
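The file removed below was the integration test for two extension points; since the deletion drops their inline documentation from this area, here is the surface it exercised, condensed from the deleted code itself (signatures abbreviated, bodies elided):

// Condensed from the deleted tests that follow. Two extension points:
//
// Schema adaptation (batch-level):
//   SchemaAdapterFactory::create(projected_table_schema, table_schema)
//       -> Box<dyn SchemaAdapter>
//   SchemaAdapter::map_schema(file_schema)
//       -> (Arc<dyn SchemaMapper>, Vec<usize>)    // mapper + projection
//   SchemaMapper::map_batch(batch) -> RecordBatch // cast / fill defaults
//
// Expression adaptation (predicate-level, used for pushdown):
//   PhysicalExprAdapterFactory::create(logical_schema, physical_schema)
//       -> Arc<dyn PhysicalExprAdapter>
//   PhysicalExprAdapter::rewrite(expr) -> Result<Arc<dyn PhysicalExpr>>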
- -use std::sync::Arc; - -use arrow::array::{record_batch, RecordBatch, RecordBatchOptions}; -use arrow::compute::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; -use bytes::{BufMut, BytesMut}; -use datafusion::assert_batches_eq; -use datafusion::common::Result; -use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableConfigExt, -}; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::DataFusionError; -use datafusion_common::{ColumnStatistics, ScalarValue}; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, -}; -use datafusion_datasource::ListingTableUrl; -use datafusion_datasource_parquet::source::ParquetSource; -use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr_adapter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, - PhysicalExprAdapterFactory, -}; -use itertools::Itertools; -use object_store::{memory::InMemory, path::Path, ObjectStore}; -use parquet::arrow::ArrowWriter; - -async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { - let mut out = BytesMut::new().writer(); - { - let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); - writer.write(&batch).unwrap(); - writer.finish().unwrap(); - } - let data = out.into_inner().freeze(); - store.put(&Path::from(path), data.into()).await.unwrap(); -} - -#[derive(Debug)] -struct CustomSchemaAdapterFactory; - -impl SchemaAdapterFactory for CustomSchemaAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(CustomSchemaAdapter { - logical_file_schema: projected_table_schema, - }) - } -} - -#[derive(Debug)] -struct CustomSchemaAdapter { - logical_file_schema: SchemaRef, -} - -impl SchemaAdapter for CustomSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - for (idx, field) in file_schema.fields().iter().enumerate() { - if field.name() == self.logical_file_schema.field(index).name() { - return Some(idx); - } - } - None - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - let projection = (0..file_schema.fields().len()).collect_vec(); - Ok(( - Arc::new(CustomSchemaMapper { - logical_file_schema: Arc::clone(&self.logical_file_schema), - }), - projection, - )) - } -} - -#[derive(Debug)] -struct CustomSchemaMapper { - logical_file_schema: SchemaRef, -} - -impl SchemaMapper for CustomSchemaMapper { - fn map_batch(&self, batch: RecordBatch) -> Result { - let mut output_columns = - Vec::with_capacity(self.logical_file_schema.fields().len()); - for field in self.logical_file_schema.fields() { - if let Some(array) = batch.column_by_name(field.name()) { - output_columns.push(cast_with_options( - array, - field.data_type(), - &CastOptions::default(), - )?); - } else { - // Create a new array with the default value for the field type - let default_value = match field.data_type() { - DataType::Int64 => ScalarValue::Int64(Some(0)), - DataType::Utf8 => ScalarValue::Utf8(Some("a".to_string())), - _ => 
unimplemented!("Unsupported data type: {}", field.data_type()), - }; - output_columns - .push(default_value.to_array_of_size(batch.num_rows()).unwrap()); - } - } - let batch = RecordBatch::try_new_with_options( - Arc::clone(&self.logical_file_schema), - output_columns, - &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), - ) - .unwrap(); - Ok(batch) - } - - fn map_column_statistics( - &self, - _file_col_statistics: &[ColumnStatistics], - ) -> Result> { - Ok(vec![ - ColumnStatistics::new_unknown(); - self.logical_file_schema.fields().len() - ]) - } -} - -// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with the default value for the field type -#[derive(Debug)] -struct CustomPhysicalExprAdapterFactory; - -impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { - fn create( - &self, - logical_file_schema: SchemaRef, - physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(CustomPhysicalExprAdapter { - logical_file_schema: Arc::clone(&logical_file_schema), - physical_file_schema: Arc::clone(&physical_file_schema), - inner: Arc::new(DefaultPhysicalExprAdapter::new( - logical_file_schema, - physical_file_schema, - )), - }) - } -} - -#[derive(Debug, Clone)] -struct CustomPhysicalExprAdapter { - logical_file_schema: SchemaRef, - physical_file_schema: SchemaRef, - inner: Arc, -} - -impl PhysicalExprAdapter for CustomPhysicalExprAdapter { - fn rewrite(&self, mut expr: Arc) -> Result> { - expr = expr - .transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - let field_name = column.name(); - if self - .physical_file_schema - .field_with_name(field_name) - .ok() - .is_none() - { - let field = self - .logical_file_schema - .field_with_name(field_name) - .map_err(|_| { - DataFusionError::Plan(format!( - "Field '{field_name}' not found in logical file schema", - )) - })?; - // If the field does not exist, create a default value expression - // Note that we use slightly different logic here to create a default value so that we can see different behavior in tests - let default_value = match field.data_type() { - DataType::Int64 => ScalarValue::Int64(Some(1)), - DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())), - _ => unimplemented!( - "Unsupported data type: {}", - field.data_type() - ), - }; - return Ok(Transformed::yes(Arc::new( - expressions::Literal::new(default_value), - ))); - } - } - - Ok(Transformed::no(expr)) - }) - .data()?; - self.inner.rewrite(expr) - } - - fn with_partition_values( - &self, - partition_values: Vec<(FieldRef, ScalarValue)>, - ) -> Arc { - assert!( - partition_values.is_empty(), - "Partition values are not supported in this test" - ); - Arc::new(self.clone()) - } -} - -#[tokio::test] -async fn test_custom_schema_adapter_and_custom_expression_adapter() { - let batch = - record_batch!(("extra", Int64, [1, 2, 3]), ("c1", Int32, [1, 2, 3])).unwrap(); - - let store = Arc::new(InMemory::new()) as Arc; - let store_url = ObjectStoreUrl::parse("memory://").unwrap(); - let path = "test.parquet"; - write_parquet(batch, store.clone(), path).await; - - let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int64, false), - Field::new("c2", DataType::Utf8, true), - ])); - - let mut cfg = SessionConfig::new() - // Disable statistics collection for this test otherwise early pruning makes it hard to demonstrate data adaptation - .with_collect_statistics(false) - .with_parquet_pruning(false) - .with_parquet_page_index_pruning(false); - 
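A reading aid for the deleted test below (and the reason its expected outputs differ per configuration): the two custom adapters intentionally fill missing columns with different defaults, so the result row reveals which adapter handled the missing c2 column. In summary (values taken from the deleted code above):

// Missing-column defaults in the deleted test, per adapter:
//   CustomSchemaAdapter       -> Int64(0), Utf8("a")  (projection path)
//   CustomPhysicalExprAdapter -> Int64(1), Utf8("b")  (predicate pushdown path)
//
// Hence: custom schema adapter alone matches c2 = 'a'; custom expr
// adapter alone matches the pushed-down c2 = 'b' but projects NULL;
// both together match c2 = 'b' yet project 'a'.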
cfg.options_mut().execution.parquet.pushdown_filters = true; - let ctx = SessionContext::new_with_config(cfg); - ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); - assert!( - !ctx.state() - .config_mut() - .options_mut() - .execution - .collect_statistics - ); - assert!(!ctx.state().config().collect_statistics()); - - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)) - .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); - - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 IS NULL") - .await - .unwrap() - .collect() - .await - .unwrap(); - - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); - - // Test using a custom schema adapter and no explicit physical expr adapter - // This should use the custom schema adapter both for projections and predicate pushdown - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory)); - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.deregister_table("t").unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'a'") - .await - .unwrap() - .collect() - .await - .unwrap(); - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| a | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); - - // Do the same test but with a custom physical expr adapter - // Now the default schema adapter will be used for projections, but the custom physical expr adapter will be used for predicate pushdown - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.deregister_table("t").unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") - .await - .unwrap() - .collect() - .await - .unwrap(); - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); - - // If we use both then the custom physical expr adapter will be used for predicate pushdown and the custom schema adapter will be used for projections - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory)) - .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.deregister_table("t").unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - let batches = ctx - .sql("SELECT c2, c1 
FROM t WHERE c1 = 2 AND c2 = 'b'")
-        .await
-        .unwrap()
-        .collect()
-        .await
-        .unwrap();
-    let expected = [
-        "+----+----+",
-        "| c2 | c1 |",
-        "+----+----+",
-        "| a  | 2  |",
-        "+----+----+",
-    ];
-    assert_batches_eq!(expected, &batches);
-}
-
-/// A test schema adapter factory that adds prefix to column names
-#[derive(Debug)]
-struct PrefixAdapterFactory {
-    prefix: String,
-}
-
-impl SchemaAdapterFactory for PrefixAdapterFactory {
-    fn create(
-        &self,
-        projected_table_schema: SchemaRef,
-        _table_schema: SchemaRef,
-    ) -> Box<dyn SchemaAdapter> {
-        Box::new(PrefixAdapter {
-            input_schema: projected_table_schema,
-            prefix: self.prefix.clone(),
-        })
-    }
-}
-
-/// A test schema adapter that adds prefix to column names
-#[derive(Debug)]
-struct PrefixAdapter {
-    input_schema: SchemaRef,
-    prefix: String,
-}
-
-impl SchemaAdapter for PrefixAdapter {
-    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
-        let field = self.input_schema.field(index);
-        file_schema.fields.find(field.name()).map(|(i, _)| i)
-    }
-
-    fn map_schema(
-        &self,
-        file_schema: &Schema,
-    ) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
-        let mut projection = Vec::with_capacity(file_schema.fields().len());
-        for (file_idx, file_field) in file_schema.fields().iter().enumerate() {
-            if self.input_schema.fields().find(file_field.name()).is_some() {
-                projection.push(file_idx);
-            }
-        }
-
-        // Create a schema mapper that adds a prefix to column names
-        #[derive(Debug)]
-        struct PrefixSchemaMapping {
-            // Keep only the prefix field which is actually used in the implementation
-            prefix: String,
-        }
-
-        impl SchemaMapper for PrefixSchemaMapping {
-            fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
-                // Create a new schema with prefixed field names
-                let prefixed_fields: Vec<Field> = batch
-                    .schema()
-                    .fields()
-                    .iter()
-                    .map(|field| {
-                        Field::new(
-                            format!("{}{}", self.prefix, field.name()),
-                            field.data_type().clone(),
-                            field.is_nullable(),
-                        )
-                    })
-                    .collect();
-                let prefixed_schema = Arc::new(Schema::new(prefixed_fields));
-
-                // Create a new batch with the prefixed schema but the same data
-                let options = RecordBatchOptions::default();
-                RecordBatch::try_new_with_options(
-                    prefixed_schema,
-                    batch.columns().to_vec(),
-                    &options,
-                )
-                .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
-            }
-
-            fn map_column_statistics(
-                &self,
-                stats: &[ColumnStatistics],
-            ) -> Result<Vec<ColumnStatistics>> {
-                // For testing, just return the input statistics
-                Ok(stats.to_vec())
-            }
-        }
-
-        Ok((
-            Arc::new(PrefixSchemaMapping {
-                prefix: self.prefix.clone(),
-            }),
-            projection,
-        ))
-    }
-}
-
-#[test]
-fn test_apply_schema_adapter_with_factory() {
-    // Create a schema
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("id", DataType::Int32, false),
-        Field::new("name", DataType::Utf8, true),
-    ]));
-
-    // Create a parquet source
-    let source = ParquetSource::default();
-
-    // Create a file scan config with source that has a schema adapter factory
-    let factory = Arc::new(PrefixAdapterFactory {
-        prefix: "test_".to_string(),
-    });
-
-    let file_source = source.clone().with_schema_adapter_factory(factory).unwrap();
-
-    let config = FileScanConfigBuilder::new(
-        ObjectStoreUrl::local_filesystem(),
-        schema.clone(),
-        file_source,
-    )
-    .build();
-
-    // Apply schema adapter to a new source
-    let result_source = source.apply_schema_adapter(&config).unwrap();
-
-    // Verify the adapter was applied
-    assert!(result_source.schema_adapter_factory().is_some());
-
-    // Create adapter and test it produces expected schema
-    let adapter_factory = result_source.schema_adapter_factory().unwrap();
-    let adapter = adapter_factory.create(schema.clone(), schema.clone());
-
-    // Create a dummy batch to test the schema mapping
-    let dummy_batch = RecordBatch::new_empty(schema.clone());
-
-    // Get the file schema (which is the same as the table schema in this test)
-    let (mapper, _) = adapter.map_schema(&schema).unwrap();
-
-    // Apply the mapping to get the output schema
-    let mapped_batch = mapper.map_batch(dummy_batch).unwrap();
-    let output_schema = mapped_batch.schema();
-
-    // Check the column names have the prefix
-    assert_eq!(output_schema.field(0).name(), "test_id");
-    assert_eq!(output_schema.field(1).name(), "test_name");
-}
-
-#[test]
-fn test_apply_schema_adapter_without_factory() {
-    // Create a schema
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("id", DataType::Int32, false),
-        Field::new("name", DataType::Utf8, true),
-    ]));
-
-    // Create a parquet source
-    let source = ParquetSource::default();
-
-    // Convert to Arc<dyn FileSource>
-    let file_source: Arc<dyn FileSource> = Arc::new(source.clone());
-
-    // Create a file scan config without a schema adapter factory
-    let config = FileScanConfigBuilder::new(
-        ObjectStoreUrl::local_filesystem(),
-        schema.clone(),
-        file_source,
-    )
-    .build();
-
-    // Apply schema adapter function - should pass through the source unchanged
-    let result_source = source.apply_schema_adapter(&config).unwrap();
-
-    // Verify no adapter was applied
-    assert!(result_source.schema_adapter_factory().is_none());
-}
diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs
index 9be391a9108e6..6f7e2e328d0c3 100644
--- a/datafusion/core/tests/parquet/schema_coercion.rs
+++ b/datafusion/core/tests/parquet/schema_coercion.rs
@@ -18,16 +18,16 @@
 use std::sync::Arc;
 
 use arrow::array::{
-    types::Int32Type, ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch,
-    StringArray,
+    ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch, StringArray,
+    types::Int32Type,
 };
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion::datasource::physical_plan::ParquetSource;
 use datafusion::physical_plan::collect;
 use datafusion::prelude::SessionContext;
 use datafusion::test::object_store::local_unpartitioned_file;
-use datafusion_common::test_util::batches_to_sort_string;
 use datafusion_common::Result;
+use datafusion_common::test_util::batches_to_sort_string;
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
@@ -62,14 +62,10 @@ async fn multi_parquet_coercion() {
         Field::new("c2", DataType::Int32, true),
         Field::new("c3", DataType::Float64, true),
     ]));
-    let source = Arc::new(ParquetSource::default());
-    let conf = FileScanConfigBuilder::new(
-        ObjectStoreUrl::local_filesystem(),
-        file_schema,
-        source,
-    )
-    .with_file_group(file_group)
-    .build();
+    let source = Arc::new(ParquetSource::new(file_schema.clone()));
+    let conf = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source)
+        .with_file_group(file_group)
+        .build();
 
     let parquet_exec = DataSourceExec::from_data_source(conf);
 
@@ -122,11 +118,11 @@ async fn multi_parquet_coercion_projection() {
     ]));
     let config = FileScanConfigBuilder::new(
         ObjectStoreUrl::local_filesystem(),
-        file_schema,
-        Arc::new(ParquetSource::default()),
+        Arc::new(ParquetSource::new(file_schema)),
     )
     .with_file_group(file_group)
     .with_projection_indices(Some(vec![1, 0, 2]))
+    .unwrap()
     .build();
 
     let parquet_exec = DataSourceExec::from_data_source(config);
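The recurring mechanical change in the scan-construction hunks above and below is that the file schema now travels with the source rather than with the scan config builder. A minimal before/after sketch of the pattern, with `url`, `schema`, and `file_group` standing in for the concrete values used in these tests:

    // Before: the builder took the schema, and the source was built with default().
    let source = Arc::new(ParquetSource::default());
    let conf = FileScanConfigBuilder::new(url, schema, source)
        .with_file_group(file_group)
        .build();

    // After: the source owns the schema, so the builder takes only the URL and source.
    let source = Arc::new(ParquetSource::new(schema));
    let conf = FileScanConfigBuilder::new(url, source)
        .with_file_group(file_group)
        .build();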
diff --git a/datafusion/core/tests/parquet/utils.rs b/datafusion/core/tests/parquet/utils.rs
index 24b6cadc148f8..77bc808f1ea08 100644
--- a/datafusion/core/tests/parquet/utils.rs
+++ b/datafusion/core/tests/parquet/utils.rs
@@ -20,7 +20,7 @@
 use datafusion::datasource::physical_plan::ParquetSource;
 use datafusion::datasource::source::DataSourceExec;
 use datafusion_physical_plan::metrics::MetricsSet;
-use datafusion_physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor};
+use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanVisitor, accept};
 
 /// Find the metrics from the first DataSourceExec encountered in the plan
 #[derive(Debug)]
@@ -47,13 +47,12 @@ impl MetricsFinder {
 impl ExecutionPlanVisitor for MetricsFinder {
     type Error = std::convert::Infallible;
     fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result<bool, Self::Error> {
-        if let Some(data_source_exec) = plan.as_any().downcast_ref::<DataSourceExec>() {
-            if data_source_exec
+        if let Some(data_source_exec) = plan.downcast_ref::<DataSourceExec>()
+            && data_source_exec
                 .downcast_to_file_source::<ParquetSource>()
                 .is_some()
-            {
-                self.metrics = data_source_exec.metrics();
-            }
+        {
+            self.metrics = data_source_exec.metrics();
         }
         // stop searching once we have found the metrics
         Ok(self.metrics.is_none())
diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
index a79d743cb253d..808e163b08369 100644
--- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
@@ -20,26 +20,38 @@
 use std::sync::Arc;
 
 use crate::physical_optimizer::test_utils::TestAggregate;
 use arrow::array::Int32Array;
+use arrow::array::{Int64Array, StringArray};
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::record_batch::RecordBatch;
+use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::memory::MemTable;
 use datafusion::datasource::memory::MemorySourceConfig;
+use datafusion::datasource::physical_plan::ParquetSource;
 use datafusion::datasource::source::DataSourceExec;
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::assert_batches_eq;
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::Result;
+use datafusion_common::stats::Precision;
+use datafusion_common::{ColumnStatistics, Result, Statistics};
+use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_execution::TaskContext;
+use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_expr::Operator;
+use datafusion_functions_aggregate::count::count_udaf;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
 use datafusion_physical_expr::expressions::{self, cast};
-use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
+use datafusion_physical_plan::ExecutionPlan;
 use datafusion_physical_plan::aggregates::AggregateExec;
 use datafusion_physical_plan::aggregates::AggregateMode;
 use datafusion_physical_plan::aggregates::PhysicalGroupBy;
 use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion_physical_plan::common;
+use datafusion_physical_plan::displayable;
 use datafusion_physical_plan::filter::FilterExec;
 use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::ExecutionPlan;
 
 /// Mock data using a MemorySourceConfig which has an exact count statistic
 fn mock_data() -> Result<Arc<DataSourceExec>> {
@@ -71,7 +83,7 @@ async fn assert_count_optim_success(
     let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?;
 
     // A ProjectionExec is a sign that the count optimization was applied
-    assert!(optimized.as_any().is::<ProjectionExec>());
+    assert!(optimized.is::<ProjectionExec>());
 
     // run both the optimized and nonoptimized plan
     let optimized_result =
@@ -268,7 +280,7 @@ async fn test_count_inexact_stat() -> Result<()> {
     let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
 
     // check that the original ExecutionPlan was not replaced
-    assert!(optimized.as_any().is::<AggregateExec>());
+    assert!(optimized.is::<AggregateExec>());
 
     Ok(())
 }
@@ -312,7 +324,232 @@ async fn test_count_with_nulls_inexact_stat() -> Result<()> {
     let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
 
     // check that the original ExecutionPlan was not replaced
-    assert!(optimized.as_any().is::<AggregateExec>());
+    assert!(optimized.is::<AggregateExec>());
+
+    Ok(())
+}
+
+/// Tests that TopK aggregation correctly handles UTF-8 (string) types in both grouping keys and aggregate values.
+///
+/// The TopK optimization is designed to efficiently handle `GROUP BY ... ORDER BY aggregate LIMIT n` queries
+/// by maintaining only the top K groups during aggregation. However, not all type combinations are supported.
+///
+/// This test verifies two scenarios:
+/// 1. **Supported case**: UTF-8 grouping key with numeric aggregate (max/min) - should use TopK optimization
+/// 2. **Unsupported case**: UTF-8 grouping key with UTF-8 aggregate value - must gracefully fall back to
+///    standard aggregation without panicking
+///
+/// The fallback behavior is critical because attempting to use TopK with unsupported types could cause
+/// runtime panics. This test ensures the optimizer correctly detects incompatible types and chooses
+/// the appropriate execution path.
+#[tokio::test]
+async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> {
+    let mut config = SessionConfig::new();
+    config.options_mut().optimizer.enable_topk_aggregation = true;
+    let ctx = SessionContext::new_with_config(config);
+
+    let batch = RecordBatch::try_new(
+        Arc::new(Schema::new(vec![
+            Field::new("g", DataType::Utf8, false),
+            Field::new("val_str", DataType::Utf8, false),
+            Field::new("val_num", DataType::Int64, false),
+        ])),
+        vec![
+            Arc::new(StringArray::from(vec!["a", "b", "a"])),
+            Arc::new(StringArray::from(vec!["alpha", "bravo", "charlie"])),
+            Arc::new(Int64Array::from(vec![1, 2, 3])),
+        ],
+    )?;
+    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
+    ctx.register_table("t", Arc::new(table))?;
+
+    // Supported path: numeric min/max with UTF-8 grouping should still use TopK aggregation
+    // and return correct results.
+    let supported_df = ctx
+        .sql("SELECT g, max(val_num) AS m FROM t GROUP BY g ORDER BY m DESC LIMIT 1")
+        .await?;
+    let supported_batches = supported_df.collect().await?;
+    assert_batches_eq!(
+        &[
+            "+---+---+",
+            "| g | m |",
+            "+---+---+",
+            "| a | 3 |",
+            "+---+---+"
+        ],
+        &supported_batches
+    );
+
+    // Unsupported TopK value type: string min/max should fall back without panicking.
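+    // The TopK path is chosen at plan time, so an unsupported Utf8 aggregate value
+    // should keep the regular grouped aggregation stream; the plan inspection below
+    // checks for this directly rather than relying only on the absence of a panic.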
+    let unsupported_df = ctx
+        .sql("SELECT g, max(val_str) AS s FROM t GROUP BY g ORDER BY s DESC LIMIT 1")
+        .await?;
+    let unsupported_plan = unsupported_df.clone().create_physical_plan().await?;
+    let unsupported_batches = unsupported_df.collect().await?;
+
+    // Ensure the plan avoided the TopK-specific stream implementation.
+    let plan_display = displayable(unsupported_plan.as_ref())
+        .indent(true)
+        .to_string();
+    assert!(
+        !plan_display.contains("GroupedTopKAggregateStream"),
+        "Unsupported UTF-8 aggregate value should not use TopK: {plan_display}"
+    );
+
+    assert_batches_eq!(
+        &[
+            "+---+---------+",
+            "| g | s       |",
+            "+---+---------+",
+            "| a | charlie |",
+            "+---+---------+"
+        ],
+        &unsupported_batches
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_count_distinct_optimization() -> Result<()> {
+    struct TestCase {
+        name: &'static str,
+        distinct_count: Precision<usize>,
+        use_column_expr: bool,
+        expect_optimized: bool,
+        expected_value: Option<i64>,
+    }
+
+    let cases = vec![
+        TestCase {
+            name: "exact statistics",
+            distinct_count: Precision::Exact(42),
+            use_column_expr: true,
+            expect_optimized: true,
+            expected_value: Some(42),
+        },
+        TestCase {
+            name: "absent statistics",
+            distinct_count: Precision::Absent,
+            use_column_expr: true,
+            expect_optimized: false,
+            expected_value: None,
+        },
+        TestCase {
+            name: "inexact statistics",
+            distinct_count: Precision::Inexact(42),
+            use_column_expr: true,
+            expect_optimized: false,
+            expected_value: None,
+        },
+        TestCase {
+            name: "non-column expression with exact statistics",
+            distinct_count: Precision::Exact(42),
+            use_column_expr: false,
+            expect_optimized: false,
+            expected_value: None,
+        },
+    ];
+
+    for case in cases {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        let statistics = Statistics {
+            num_rows: Precision::Exact(100),
+            total_byte_size: Precision::Absent,
+            column_statistics: vec![
+                ColumnStatistics {
+                    distinct_count: case.distinct_count,
+                    null_count: Precision::Exact(10),
+                    ..Default::default()
+                },
+                ColumnStatistics::default(),
+            ],
+        };
+
+        let config = FileScanConfigBuilder::new(
+            ObjectStoreUrl::parse("test:///").unwrap(),
+            Arc::new(ParquetSource::new(Arc::clone(&schema))),
+        )
+        .with_file(PartitionedFile::new("x".to_string(), 100))
+        .with_statistics(statistics)
+        .build();
+
+        let source: Arc<dyn ExecutionPlan> = DataSourceExec::from_data_source(config);
+        let schema = source.schema();
+
+        let (agg_args, alias): (Vec<Arc<dyn PhysicalExpr>>, _) =
+            if case.use_column_expr {
+                (vec![expressions::col("a", &schema)?], "COUNT(DISTINCT a)")
+            } else {
+                (
+                    vec![expressions::binary(
+                        expressions::col("a", &schema)?,
+                        Operator::Plus,
+                        expressions::col("b", &schema)?,
+                        &schema,
+                    )?],
+                    "COUNT(DISTINCT a + b)",
+                )
+            };
+
+        let count_distinct_expr = AggregateExprBuilder::new(count_udaf(), agg_args)
+            .schema(Arc::clone(&schema))
+            .alias(alias)
+            .distinct()
+            .build()?;
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![Arc::new(count_distinct_expr.clone())],
+            vec![None],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![Arc::new(count_distinct_expr)],
+            vec![None],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        let conf = ConfigOptions::new();
+        let optimized =
+            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+        if case.expect_optimized {
+            assert!(
+                optimized.is::<ProjectionExec>(),
+                "'{}': expected ProjectionExec",
+                case.name
+            );
+
+            if let Some(expected_val) = case.expected_value {
+                let task_ctx = Arc::new(TaskContext::default());
+                let result = common::collect(optimized.execute(0, task_ctx)?).await?;
+                assert_eq!(result.len(), 1, "'{}': expected 1 batch", case.name);
+                assert_eq!(
+                    as_int64_array(result[0].column(0)).unwrap().values(),
+                    &[expected_val],
+                    "'{}': unexpected value",
+                    case.name
+                );
+            }
+        } else {
+            assert!(
+                optimized.is::<AggregateExec>(),
+                "'{}': expected AggregateExec (not optimized)",
+                case.name
+            );
+        }
+    }
 
     Ok(())
 }
diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs
index 9c76f6ab6f58b..9e63c341c92d9 100644
--- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs
@@ -29,18 +29,18 @@
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_functions_aggregate::sum::sum_udaf;
+use datafusion_physical_expr::Partitioning;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 use datafusion_physical_expr::expressions::{col, lit};
-use datafusion_physical_expr::Partitioning;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
-use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
+use datafusion_physical_plan::ExecutionPlan;
 use datafusion_physical_plan::aggregates::{
-    AggregateExec, AggregateMode, PhysicalGroupBy,
+    AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy,
 };
 use datafusion_physical_plan::displayable;
 use datafusion_physical_plan::repartition::RepartitionExec;
-use datafusion_physical_plan::ExecutionPlan;
 
 /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected
 macro_rules! assert_optimized {
@@ -191,7 +191,7 @@ fn aggregations_combined() -> datafusion_common::Result<()> {
     // should combine the Partial/Final AggregateExecs to the Single AggregateExec
     assert_optimized!(
         plan,
-        @ "
+        @ r"
     AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)]
       DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet
     "
@@ -260,7 +260,7 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> {
             schema,
         )
         .unwrap()
-        .with_limit(Some(5)),
+        .with_limit_options(Some(LimitOptions::new(5))),
     );
     let plan: Arc<dyn ExecutionPlan> = final_agg;
     // should combine the Partial/Final AggregateExecs to a Single AggregateExec
diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs
index db011c4be43ab..78bb02ab1108b 100644
--- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs
@@ -26,36 +26,42 @@
 use crate::physical_optimizer::test_utils::{
     sort_preserving_merge_exec, union_exec,
 };
-use arrow::array::{RecordBatch, UInt64Array, UInt8Array};
+use arrow::array::{RecordBatch, UInt8Array, UInt64Array};
 use arrow::compute::SortOptions;
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use datafusion::config::ConfigOptions;
+use datafusion::datasource::MemTable;
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
 use datafusion::datasource::listing::PartitionedFile;
 use datafusion::datasource::object_store::ObjectStoreUrl;
 use datafusion::datasource::physical_plan::{CsvSource, ParquetSource};
 use datafusion::datasource::source::DataSourceExec;
-use datafusion::datasource::MemTable;
 use datafusion::prelude::{SessionConfig, SessionContext};
-use datafusion_common::error::Result;
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::ScalarValue;
+use datafusion_common::config::CsvOptions;
+use datafusion_common::error::Result;
+use datafusion_common::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
+};
 use datafusion_datasource::file_groups::FileGroup;
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_expr::{JoinType, Operator};
-use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal};
+use datafusion_functions_aggregate::count::count_udaf;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_expr_common::sort_expr::{
     LexOrdering, OrderingRequirements, PhysicalSortExpr,
 };
+use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_optimizer::enforce_distribution::*;
 use datafusion_physical_optimizer::enforce_sorting::EnforceSorting;
 use datafusion_physical_optimizer::output_requirements::OutputRequirements;
-use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_plan::aggregates::{
     AggregateExec, AggregateMode, PhysicalGroupBy,
 };
-use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
+
+use datafusion_physical_expr::Distribution;
 use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion_physical_plan::execution_plan::ExecutionPlan;
 use datafusion_physical_plan::expressions::col;
@@ -66,8 +72,7 @@ use
datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - displayable, get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, - PlanProperties, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, displayable, }; use insta::Settings; @@ -119,7 +124,7 @@ macro_rules! assert_plan { struct SortRequiredExec { input: Arc, expr: LexOrdering, - cache: PlanProperties, + cache: Arc, } impl SortRequiredExec { @@ -131,7 +136,7 @@ impl SortRequiredExec { Self { input, expr: requirement, - cache, + cache: Arc::new(cache), } } @@ -169,11 +174,7 @@ impl ExecutionPlan for SortRequiredExec { "SortRequiredExec" } - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -202,6 +203,20 @@ impl ExecutionPlan for SortRequiredExec { ))) } + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } + fn execute( &self, _partition: usize, @@ -209,12 +224,104 @@ impl ExecutionPlan for SortRequiredExec { ) -> Result { unreachable!(); } +} + +#[derive(Debug)] +struct SinglePartitionMaintainsOrderExec { + input: Arc, + cache: Arc, +} + +impl SinglePartitionMaintainsOrderExec { + fn new(input: Arc) -> Self { + let cache = Self::compute_properties(&input); + Self { + input, + cache: Arc::new(cache), + } + } + + fn compute_properties(input: &Arc) -> PlanProperties { + PlanProperties::new( + input.equivalence_properties().clone(), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + ) + } +} + +impl DisplayAs for SinglePartitionMaintainsOrderExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "SinglePartitionMaintainsOrderExec") + } + DisplayFormatType::TreeRender => write!(f, ""), + } + } +} + +impl ExecutionPlan for SinglePartitionMaintainsOrderExec { + fn name(&self) -> &'static str { + "SinglePartitionMaintainsOrderExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::SinglePartition] + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 1); + let child = children.pop().unwrap(); + Ok(Arc::new(Self::new(child))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!(); } } +fn single_partition_maintains_order_exec( + input: Arc, +) -> Arc { + Arc::new(SinglePartitionMaintainsOrderExec::new(input)) +} + fn parquet_exec() 
-> Arc { parquet_exec_with_sort(schema(), vec![]) } @@ -229,8 +336,7 @@ fn parquet_exec_multiple_sorted( ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema())), ) .with_file_groups(vec![ FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), @@ -247,14 +353,19 @@ fn csv_exec() -> Arc { } fn csv_exec_with_sort(output_ordering: Vec) -> Arc { - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(output_ordering) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(output_ordering) + .build(); DataSourceExec::from_data_source(config) } @@ -265,17 +376,22 @@ fn csv_exec_multiple() -> Arc { // Created a sorted parquet exec with multiple files fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc { - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file_groups(vec![ - FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), - FileGroup::new(vec![PartitionedFile::new("y".to_string(), 100)]), - ]) - .with_output_ordering(output_ordering) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file_groups(vec![ + FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), + FileGroup::new(vec![PartitionedFile::new("y".to_string(), 100)]), + ]) + .with_output_ordering(output_ordering) + .build(); DataSourceExec::from_data_source(config) } @@ -340,6 +456,71 @@ fn aggregate_exec_with_alias( ) } +fn partitioned_count_aggregate_exec( + input: Arc, + group_alias_pairs: Vec<(String, String)>, + count_column: &str, +) -> Arc { + let input_schema = input.schema(); + let group_by_expr = group_alias_pairs + .iter() + .map(|(column, alias)| { + ( + col(column, &input_schema).unwrap() as Arc, + alias.clone(), + ) + }) + .collect::>(); + let partial_group_by = PhysicalGroupBy::new_single(group_by_expr.clone()); + let final_group_by = PhysicalGroupBy::new_single( + group_by_expr + .iter() + .enumerate() + .map(|(idx, (_expr, alias))| { + ( + Arc::new(Column::new(alias, idx)) as Arc, + alias.clone(), + ) + }) + .collect::>(), + ); + + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new( + count_udaf(), + vec![col(count_column, &input_schema).unwrap()], + ) + .schema(Arc::clone(&input_schema)) + .alias(format!("COUNT({count_column})")) + .build() + .unwrap(), + )]; + + let partial = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + partial_group_by, + aggr_expr.clone(), + vec![None], + input, + Arc::clone(&input_schema), + ) + .unwrap(), + ); + + Arc::new( + AggregateExec::try_new( + AggregateMode::FinalPartitioned, + final_group_by, + aggr_expr, + vec![None], + Arc::clone(&partial) as _, + partial.schema(), + 
) + .unwrap(), + ) +} + fn hash_join_exec( left: Arc, right: Arc, @@ -469,83 +650,6 @@ impl TestConfig { self } - // This be deleted in https://github.com/apache/datafusion/pull/18185 - /// Perform a series of runs using the current [`TestConfig`], - /// assert the expected plan result, - /// and return the result plan (for potential subsequent runs). - fn run( - &self, - expected_lines: &[&str], - plan: Arc, - optimizers_to_run: &[Run], - ) -> Result> { - let expected_lines: Vec<&str> = expected_lines.to_vec(); - - // Add the ancillary output requirements operator at the start: - let optimizer = OutputRequirements::new_add_mode(); - let mut optimized = optimizer.optimize(plan.clone(), &self.config)?; - - // This file has 2 rules that use tree node, apply these rules to original plan consecutively - // After these operations tree nodes should be in a consistent state. - // This code block makes sure that these rules doesn't violate tree node integrity. - { - let adjusted = if self.config.optimizer.top_down_join_key_reordering { - // Run adjust_input_keys_ordering rule - let plan_requirements = - PlanWithKeyRequirements::new_default(plan.clone()); - let adjusted = plan_requirements - .transform_down(adjust_input_keys_ordering) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - adjusted.plan - } else { - // Run reorder_join_keys_to_inputs rule - plan.clone() - .transform_up(|plan| { - Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) - }) - .data()? - }; - - // Then run ensure_distribution rule - DistributionContext::new_default(adjusted) - .transform_up(|distribution_context| { - ensure_distribution(distribution_context, &self.config) - }) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - } - - for run in optimizers_to_run { - optimized = match run { - Run::Distribution => { - let optimizer = EnforceDistribution::new(); - optimizer.optimize(optimized, &self.config)? - } - Run::Sorting => { - let optimizer = EnforceSorting::new(); - optimizer.optimize(optimized, &self.config)? - } - }; - } - - // Remove the ancillary output requirements operator when done: - let optimizer = OutputRequirements::new_remove_mode(); - let optimized = optimizer.optimize(optimized, &self.config)?; - - // Now format correctly - let actual_lines = get_plan_string(&optimized); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" - ); - - Ok(optimized) - } - /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, /// and return the result plan (for potential subsequent runs). 
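The removed `run()` helper compared plans against hand-maintained vectors of expected lines. The tests below now build each plan with `to_plan` and assert it with inline insta snapshots; a minimal sketch of the replacement pattern, reusing the names from the surrounding call sites (the exact `to_plan` signature is inferred from those call sites):

    // Build the optimized plan for one optimizer run order, then snapshot it.
    let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(plan_distrib, @r"
    HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)]
      ...
    ");
    // The other enforcement order can also be asserted to match the first plan.
    let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB);
    assert_plan!(plan_distrib, plan_sort);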
@@ -695,16 +799,13 @@ fn multi_hash_joins() -> Result<()> { assert_plan!(plan_distrib, @r" HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)] HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); }, // Should include 4 RepartitionExecs @@ -713,16 +814,13 @@ fn multi_hash_joins() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)] RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); }, }; @@ -767,16 +865,13 @@ fn multi_hash_joins() -> Result<()> { assert_plan!(plan_distrib, @r" HashJoinExec: mode=Partitioned, join_type=..., on=[(b1@1, c@2)] HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), 
input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); } @@ -787,16 +882,13 @@ fn multi_hash_joins() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=..., on=[(b1@6, c@2)] RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10 HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); }, @@ -857,15 +949,12 @@ fn multi_joins_after_alias() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, c@2)] ProjectionExec: expr=[a@0 as a1, a@0 as a2] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); @@ -888,15 +977,12 @@ fn multi_joins_after_alias() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a2@1, c@2)] ProjectionExec: expr=[a@0 as a1, a@0 as a2] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); @@ -946,15 +1032,12 @@ fn multi_joins_after_multi_alias() -> Result<()> { ProjectionExec: expr=[c1@0 as a] ProjectionExec: expr=[c@2 as c1] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); @@ -1175,23 +1258,19 @@ fn multi_hash_join_key_ordering() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(B@2, b1@6), (C@3, c@2), (AA@1, a1@5)] ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=1 ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - " - ); + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); let plan_sort = test_config.to_plan(filter_top_join, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -1313,25 +1392,21 @@ fn reorder_join_keys_to_left_input() -> Result<()> { assert_eq!(captured_join_type, join_type.to_string()); insta::allow_duplicates! 
{insta::assert_snapshot!(modified_plan, @r" -HashJoinExec: mode=Partitioned, join_type=..., on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)] - ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)] - RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -");} + HashJoinExec: mode=Partitioned, join_type=..., on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)] + RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + ");} } Ok(()) @@ -1445,25 +1520,21 @@ fn reorder_join_keys_to_right_input() -> Result<()> { let (_, plan_str) = hide_first(reordered.as_ref(), r"join_type=(\w+)", "join_type=..."); insta::allow_duplicates! 
{insta::assert_snapshot!(plan_str, @r" -HashJoinExec: mode=Partitioned, join_type=..., on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)] - ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)] - RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -");} + HashJoinExec: mode=Partitioned, join_type=..., on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)] + RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + ");} } Ok(()) @@ -1503,15 +1574,6 @@ fn multi_smj_joins() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let join_plan = |shift| -> String { - format!( - "{}SortMergeJoin: join_type={join_type}, on=[(a@0, b1@1)]", - " ".repeat(shift) - ) - }; - let join_plan_indent2 = join_plan(2); - let join_plan_indent6 = join_plan(6); - let join_plan_indent10 = join_plan(10); // Top join on (a == c) let top_join_on = vec![( @@ -1520,235 +1582,220 @@ fn multi_smj_joins() -> Result<()> { )]; let top_join = sort_merge_join_exec(join.clone(), parquet_exec(), &top_join_on, &join_type); - let top_join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(a@0, c@2)]"); - - let expected = match join_type { - // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => - vec![ - 
top_join_plan.as_str(), - &join_plan_indent2, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs - // Since ordering of the left child is not preserved after SortMergeJoin - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases - // when mode is Inner, Left, LeftSemi, LeftAnti - // Similarly, since partitioning of the left side is not preserved - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional Hash Repartition after SortMergeJoin in contrast the test - // cases when mode is Inner, Left, LeftSemi, LeftAnti - _ => vec![ - top_join_plan.as_str(), - // Below 2 operators are differences introduced, when join mode is changed - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - &join_plan_indent6, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - }; - // TODO(wiedld): show different test result if enforce sorting first. 
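The two run orders exercised here differ only in when sort enforcement happens. Their definitions are not part of this diff; judging from the `Run::Distribution` / `Run::Sorting` arms of the removed `run()` loop, they plausibly look like the following sketch (assumed, not shown in this patch):

    // Assumed definitions, consistent with how the removed `run()` consumed them.
    const DISTRIB_DISTRIB_SORT: [Run; 3] =
        [Run::Distribution, Run::Distribution, Run::Sorting];
    const SORT_DISTRIB_DISTRIB: [Run; 3] =
        [Run::Sorting, Run::Distribution, Run::Distribution];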
- test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - - let expected_first_sort_enforcement = match join_type { - // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => - vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs - // Since ordering of the left child is not preserved after SortMergeJoin - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases - // when mode is Inner, Left, LeftSemi, LeftAnti - // Similarly, since partitioning of the left side is not preserved - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional Hash Repartition and Roundrobin repartition after - // SortMergeJoin in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti - _ => vec![ - top_join_plan.as_str(), - // Below 4 operators are differences introduced, when join mode is changed - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - &join_plan_indent10, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 
ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - }; - // TODO(wiedld): show different test result if enforce distribution first. - test_config.run( - &expected_first_sort_enforcement, - top_join, - &SORT_DISTRIB_DISTRIB, - )?; - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - // This time we use (b1 == c) for top join - // Join on (b1 == c) - let top_join_on = vec![( - Arc::new(Column::new_with_schema("b1", &join.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, - )]; - let top_join = - sort_merge_join_exec(join, parquet_exec(), &top_join_on, &join_type); - let top_join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(b1@6, c@2)]"); - - let expected = match join_type { - // Should include 6 RepartitionExecs(3 hash, 3 round-robin) and 3 SortExecs - JoinType::Inner | JoinType::Right => vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs - JoinType::Left | JoinType::Full => vec![ - top_join_plan.as_str(), - " SortExec: expr=[b1@6 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", - &join_plan_indent6, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // this match arm cannot be reached - _ => unreachable!() - }; - // TODO(wiedld): show different test result if enforce sorting first. 
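In the replacement below, `Settings::clone_current()` plus `add_filter` normalizes the rendered join type (`join_type=Inner`, `join_type=Full`, and so on all become `join_type=...`), and `insta::allow_duplicates!` allows the same inline snapshot to be asserted on every iteration of the join-type loop, so a single snapshot replaces the per-arm `expected` vectors deleted here.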
- test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - - let expected_first_sort_enforcement = match join_type { - // Should include 6 RepartitionExecs (3 of them preserves order) and 3 SortExecs - JoinType::Inner | JoinType::Right => vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs - JoinType::Left | JoinType::Full => vec![ - top_join_plan.as_str(), - " RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@6 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - &join_plan_indent10, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // this match arm cannot be reached - _ => unreachable!() - }; + let mut settings = Settings::clone_current(); + settings.add_filter(&format!("join_type={join_type}"), "join_type=..."); - // TODO(wiedld): show different test result if enforce distribution first. - test_config.run( - &expected_first_sort_enforcement, - top_join, - &SORT_DISTRIB_DISTRIB, - )?; - } - _ => {} + #[rustfmt::skip] + insta::allow_duplicates! 
{
+ settings.bind(|| {
+ let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT);
+
+ match join_type {
+ // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs
+ JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+ assert_plan!(plan_distrib, @r"
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs
+ // Since ordering of the left child is not preserved after SortMergeJoinExec
+ // when mode is Right, RightSemi, RightAnti, Full
+ // - We need to add one additional SortExec after SortMergeJoinExec in contrast to the test cases
+ // when mode is Inner, Left, LeftSemi, LeftAnti
+ // Similarly, since partitioning of the left side is not preserved
+ // when mode is Right, RightSemi, RightAnti, Full
+ // - We need to add one additional Hash Repartition after SortMergeJoinExec in contrast to the test
+ // cases when mode is Inner, Left, LeftSemi, LeftAnti
+ _ => {
+ assert_plan!(plan_distrib, @r"
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ }
+
+ let plan_sort = test_config.to_plan(top_join.clone(), &SORT_DISTRIB_DISTRIB);
+
+ match join_type {
+ // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs
+ JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+ // TODO(wiedld): show different test result if enforce distribution first.
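+ // How the snapshot normalization above works (a reading aid, not part of
+ // the test's assertions): `settings.add_filter` registers a regex rewrite
+ // that insta applies to the rendered plan text before comparing it with
+ // the inline snapshot, so `join_type=Inner`, `join_type=Left`, etc. all
+ // collapse to `join_type=...` and a single snapshot can cover every join
+ // type in this loop, while `insta::allow_duplicates!` permits the same
+ // inline snapshot to be asserted on each iteration. A minimal standalone
+ // sketch of the same pattern, assuming only the `insta` crate:
+ //
+ //     let mut settings = insta::Settings::clone_current();
+ //     settings.add_filter(r"join_type=\w+", "join_type=...");
+ //     settings.bind(|| {
+ //         insta::assert_snapshot!("join_type=Inner", @"join_type=...");
+ //     });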
+ assert_plan!(plan_sort, @r"
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ // Should include 8 RepartitionExecs (4 hash, 4 round-robin), 4 SortExecs
+ // Since ordering of the left child is not preserved after SortMergeJoinExec
+ // when mode is Right, RightSemi, RightAnti, Full
+ // - We need to add one additional SortExec after SortMergeJoinExec in contrast to the test cases
+ // when mode is Inner, Left, LeftSemi, LeftAnti
+ // Similarly, since partitioning of the left side is not preserved
+ // when mode is Right, RightSemi, RightAnti, Full
+ // - We need to add one additional Hash Repartition and Roundrobin repartition after
+ // SortMergeJoinExec in contrast to the test cases when mode is Inner, Left, LeftSemi, LeftAnti
+ _ => {
+ // TODO(wiedld): show different test result if enforce distribution first.
+ assert_plan!(plan_sort, @r"
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ CoalescePartitionsExec
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ }
+
+ match join_type {
+ JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
+ // This time we use (b1 == c) for top join
+ // Join on (b1 == c)
+ let top_join_on = vec![(
+ Arc::new(Column::new_with_schema("b1", &join.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
+ )];
+ let top_join = sort_merge_join_exec(join, parquet_exec(), &top_join_on, &join_type);
+
+ let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT);
+
+ match join_type {
+ // Should include 6 RepartitionExecs (3 hash, 3 round-robin) and 3 SortExecs
+ JoinType::Inner | JoinType::Right => {
+ // TODO(wiedld): show different test result if enforce sorting first.
+ assert_plan!(plan_distrib, @r"
+ SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs
+ JoinType::Left | JoinType::Full => {
+ // TODO(wiedld): show different test result if enforce sorting first.
+ assert_plan!(plan_distrib, @r"
+ SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)]
+ SortExec: expr=[b1@6 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ // this match arm cannot be reached
+ _ => unreachable!()
+ }
+
+ let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB);
+
+ match join_type {
+ // Should include 6 RepartitionExecs (3 of them preserve order) and 3 SortExecs
+ JoinType::Inner | JoinType::Right => {
+ // TODO(wiedld): show different test result if enforce distribution first.
+ assert_plan!(plan_sort, @r"
+ SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ // Should include 8 RepartitionExecs (4 of them preserve order) and 4 SortExecs
+ JoinType::Left | JoinType::Full => {
+ // TODO(wiedld): show different test result if enforce distribution first.
+ assert_plan!(plan_sort, @r"
+ SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)]
+ RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@6 ASC], preserve_partitioning=[false]
+ CoalescePartitionsExec
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
+ }
+ // this match arm cannot be reached
+ _ => unreachable!()
+ }
+ }
+ _ => {}
+ }
+ });
}
}
-
Ok(())
}
@@ -1806,50 +1853,48 @@ fn smj_join_key_ordering() -> Result<()> {
// Only two RepartitionExecs added
let plan_distrib = test_config.to_plan(join.clone(), &DISTRIB_DISTRIB_SORT);
assert_plan!(plan_distrib, @r"
-SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]
- SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true]
- ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]
- ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]
- AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]
- RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10
- AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true]
- ProjectionExec: expr=[a@1 as a2, b@0 as b2]
- AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]
- RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10
- AggregateExec: mode=Partial, gby=[b@1 as b,
a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortMergeJoinExec: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] + SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@1 as a2, b@0 as b2] + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(join, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] - RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC, a3@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] - ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] - AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] - RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + SortMergeJoinExec: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] + RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@1 as a2, b@0 as b2] + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b2@1, 
a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC, a2@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a@1 as a2, b@0 as b2] - AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] - RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + "); Ok(()) } @@ -1867,9 +1912,6 @@ fn merge_does_not_need_sort() -> Result<()> { // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); - // CoalesceBatchesExec to mimic behavior after a filter - let exec = Arc::new(CoalesceBatchesExec::new(exec, 4096)); - // Merge from multiple parquet files and keep the data sorted let exec: Arc = Arc::new(SortPreservingMergeExec::new(sort_key, exec)); @@ -1882,10 +1924,9 @@ fn merge_does_not_need_sort() -> Result<()> { let plan_distrib = test_config.to_plan(exec.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - CoalesceBatchesExec: target_batch_size=4096 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: // @@ -1896,11 +1937,10 @@ SortPreservingMergeExec: [a@0 ASC] let plan_sort = test_config.to_plan(exec, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=4096 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -2077,11 +2117,11 @@ fn repartition_sorted_limit() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // data is sorted so can't repartition here let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2106,12 +2146,12 @@ fn repartition_sorted_limit_with_filter() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortRequiredExec: [c@2 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortRequiredExec: [c@2 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // We can use repartition here, ordering requirement by SortRequiredExec // is still satisfied. let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); @@ -2132,19 +2172,19 @@ fn repartition_ignores_limit() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - GlobalLimitExec: skip=0, fetch=100 - CoalescePartitionsExec - LocalLimitExec: fetch=100 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // repartition should happen prior to the filter to maximize parallelism // Expect no repartition to happen for local limit (DataSourceExec) @@ -2162,13 +2202,13 @@ fn repartition_ignores_union() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect no repartition of DataSourceExec let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2191,9 +2231,9 @@ fn repartition_through_sort_preserving_merge() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); 
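+ // A note on the pass orders compared throughout these tests (inferred from
+ // the constant names and the surrounding comments; the authoritative
+ // definitions live in the test harness): `DISTRIB_DISTRIB_SORT` applies
+ // EnforceDistribution (twice) and then EnforceSorting, while
+ // `SORT_DISTRIB_DISTRIB` applies EnforceSorting first. `to_plan` runs the
+ // given sequence and returns the optimized plan so the two orders can be
+ // compared.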
assert_plan!(plan_distrib, @r"
-SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
-");
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB);
assert_plan!(plan_distrib, plan_sort);
@@ -2219,9 +2259,9 @@ fn repartition_ignores_sort_preserving_merge() -> Result<()> {
// Test: run EnforceDistribution, then EnforceSort
assert_plan!(plan_distrib, @r"
-SortPreservingMergeExec: [c@2 ASC]
- DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet
-");
+ SortPreservingMergeExec: [c@2 ASC]
+ DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet
+ ");
// should not sort (as the data was already sorted)
// should not repartition, since increased parallelism is not beneficial for SortPreservingMerge
let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB);
assert_plan!(plan_sort, @r"
-SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
- CoalescePartitionsExec
- DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet
-");
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ CoalescePartitionsExec
+ DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet
+ ");
Ok(())
}
@@ -2257,11 +2297,11 @@ fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> {
// Test: run EnforceDistribution, then EnforceSort.
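+ // Background for the expectation below: SortPreservingMergeExec merges
+ // several already-sorted partitions into one sorted stream, so as long as
+ // each union input stays sorted on c@2, no extra SortExec or
+ // order-destroying repartition should be introduced above it.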
assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [c@2 ASC] - UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortPreservingMergeExec: [c@2 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // // should not repartition / sort (as the data was already sorted) @@ -2269,12 +2309,12 @@ SortPreservingMergeExec: [c@2 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -2301,11 +2341,11 @@ fn repartition_does_not_destroy_sort() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortRequiredExec: [d@3 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet -"); + SortRequiredExec: [d@3 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet + "); // during repartitioning ordering is preserved let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2341,13 +2381,13 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -UnionExec - SortRequiredExec: [c@2 ASC] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + UnionExec + SortRequiredExec: [c@2 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // union input 1: no repartitioning // union input 2: should repartition // @@ -2384,23 +2424,23 @@ fn repartition_transitively_with_projection() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); 
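+ // `assert_plan!` appears in two forms in this file (a reading aid,
+ // inferred from usage): `assert_plan!(plan, @r"...")` checks one plan
+ // against an inline insta snapshot, while `assert_plan!(plan_a, plan_b)`
+ // asserts that two independently optimized plans render identically, e.g.:
+ //
+ //     let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB);
+ //     assert_plan!(plan_distrib, plan_sort);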
assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [sum@0 ASC] - SortExec: expr=[sum@0 ASC], preserve_partitioning=[true] - ProjectionExec: expr=[a@0 + b@1 as sum] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [sum@0 ASC] + SortExec: expr=[sum@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@0 + b@1 as sum] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[sum@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a@0 + b@1 as sum] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[sum@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@0 + b@1 as sum] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Since this projection is not trivial, increasing parallelism is beneficial Ok(()) @@ -2432,10 +2472,10 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortRequiredExec: [c@2 ASC] - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortRequiredExec: [c@2 ASC] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Since this projection is trivial, increasing parallelism is not beneficial let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); @@ -2469,10 +2509,10 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Since this projection is trivial, increasing parallelism is not beneficial let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2494,12 +2534,12 @@ fn repartition_transitively_past_sort_with_filter() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], 
file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect repartition on the input to the sort (as it can benefit from additional parallelism) @@ -2507,12 +2547,12 @@ SortPreservingMergeExec: [a@0 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect repartition on the input of the filter (as it can benefit from additional parallelism) Ok(()) @@ -2543,13 +2583,13 @@ fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect repartition on the input to the sort (as it can benefit from additional parallelism) // repartition is lowest down @@ -2558,13 +2598,13 @@ SortPreservingMergeExec: [a@0 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -2584,11 +2624,11 @@ fn parallelization_single_partition() -> Result<()> { test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=parquet -"); + 
AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2596,11 +2636,11 @@ AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2632,10 +2672,10 @@ fn parallelization_multiple_files() -> Result<()> { test_config_concurrency_3.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_3_distrib, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let plan_3_sort = test_config_concurrency_3.to_plan(plan.clone(), &SORT_DISTRIB_DISTRIB); assert_plan!(plan_3_distrib, plan_3_sort); @@ -2645,10 +2685,10 @@ SortRequiredExec: [a@0 ASC] test_config_concurrency_8.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_8_distrib, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let plan_8_sort = test_config_concurrency_8.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_8_distrib, plan_8_sort); @@ -2667,46 +2707,55 @@ fn parallelization_compressed_csv() -> Result<()> { FileCompressionType::UNCOMPRESSED, ]; - let expected_not_partitioned = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - - let expected_partitioned = [ - "AggregateExec: 
mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; + #[rustfmt::skip] + insta::allow_duplicates! { + for compression_type in compression_types { + let plan = aggregate_exec_with_alias( + DataSourceExec::from_data_source( + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_file_compression_type(compression_type) + .build(), + ), + vec![("a".to_string(), "a".to_string())], + ); + let test_config = TestConfig::default() + .with_query_execution_partitions(2) + .with_prefer_repartition_file_scans(10); + + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + if compression_type.is_compressed() { + // Compressed files cannot be partitioned + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + } else { + // Uncompressed files can be partitioned + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + } - for compression_type in compression_types { - let expected = if compression_type.is_compressed() { - &expected_not_partitioned[..] - } else { - &expected_partitioned[..] 
- }; - - let plan = aggregate_exec_with_alias( - DataSourceExec::from_data_source( - FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_file_compression_type(compression_type) - .build(), - ), - vec![("a".to_string(), "a".to_string())], - ); - let test_config = TestConfig::default() - .with_query_execution_partitions(2) - .with_prefer_repartition_file_scans(10); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); + } } Ok(()) } @@ -2726,23 +2775,23 @@ fn parallelization_two_partitions() -> Result<()> { test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=parquet -"); - // Plan already has two partitions - let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Plan already has two partitions + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // Plan already has two partitions let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2766,11 +2815,11 @@ fn parallelization_two_partitions_into_four() -> Result<()> { // Multiple source files split across partitions assert_plan!(plan_parquet_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], 
[y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Multiple source files split across partitions let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2779,11 +2828,11 @@ AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); // Multiple source files split across partitions assert_plan!(plan_csv_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // Multiple source files split across partitions let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2808,11 +2857,11 @@ fn parallelization_sorted_limit() -> Result<()> { let plan_parquet_distrib = test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // data is sorted so can't repartition here // Doesn't parallelize for SortExec without preserve_partitioning let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); @@ -2822,11 +2871,11 @@ GlobalLimitExec: skip=0, fetch=100 let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // data is sorted so can't repartition here // Doesn't parallelize for SortExec without preserve_partitioning let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); @@ -2857,14 +2906,14 @@ fn parallelization_limit_with_filter() -> Result<()> { // SortExec doesn't benefit from input partitioning assert_plan!(plan_parquet_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - CoalescePartitionsExec - LocalLimitExec: fetch=100 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + 
GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2875,14 +2924,14 @@ GlobalLimitExec: skip=0, fetch=100 // SortExec doesn't benefit from input partitioning assert_plan!(plan_csv_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - CoalescePartitionsExec - LocalLimitExec: fetch=100 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2961,13 +3010,13 @@ fn parallelization_union_inputs() -> Result<()> { test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Union doesn't benefit from input partitioning - no parallelism let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2976,13 +3025,13 @@ UnionExec let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, 
has_header=false -"); + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // Union doesn't benefit from input partitioning - no parallelism let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -3188,9 +3237,9 @@ fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> // data should not be repartitioned / resorted assert_plan!(plan_parquet_distrib, @r" -ProjectionExec: expr=[a@0 as a2, c@2 as c2] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -3223,18 +3272,18 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); assert_plan!(plan_csv, @r" -SortPreservingMergeExec: [c2@1 ASC] - ProjectionExec: expr=[a@0 as a2, c@2 as c2] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false -"); + SortPreservingMergeExec: [c2@1 ASC] + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -ProjectionExec: expr=[a@0 as a2, c@2 as c2] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false -"); + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); // Expected Outcome: // data should not be repartitioned / resorted let plan_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); @@ -3250,21 +3299,21 @@ fn remove_redundant_roundrobins() -> Result<()> { let physical_plan = repartition_exec(filter_exec(repartition)); assert_plan!(physical_plan, @r" -RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 - FilterExec: c@2 = 0 RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = 
TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3292,11 +3341,11 @@ fn remove_unnecessary_spm_after_filter() -> Result<()> { // This is still satisfied since, after filter that column is constant. assert_plan!(plan_distrib, @r" -CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3321,11 +3370,76 @@ fn preserve_ordering_through_repartition() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [d@3 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=d@3 ASC - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet -"); + SortPreservingMergeExec: [d@3 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=d@3 ASC + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); + + Ok(()) +} + +#[test] +fn preserve_ordering_for_streaming_sorted_aggregate() -> Result<()> { + let schema = schema(); + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions::default(), + }] + .into(); + let input = parquet_exec_multiple_sorted(vec![sort_key]); + let physical_plan = partitioned_count_aggregate_exec( + input, + vec![("a".to_string(), "a".to_string())], + "b", + ); + + let test_config = TestConfig::default().with_query_execution_partitions(2); + + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); + + let plan_sort = 
test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); + + Ok(()) +} + +#[test] +fn preserve_ordering_for_streaming_partially_sorted_aggregate() -> Result<()> { + let schema = schema(); + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions::default(), + }] + .into(); + let input = parquet_exec_multiple_sorted(vec![sort_key]); + let physical_plan = partitioned_count_aggregate_exec( + input, + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ], + "c", + ); + + let test_config = TestConfig::default().with_query_execution_partitions(2); + + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0]) + RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC + AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0]) + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3349,23 +3463,23 @@ fn do_not_preserve_ordering_through_repartition() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort. assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -3384,11 +3498,11 @@ fn no_need_for_sort_after_filter() -> Result<()> { let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); // After CoalescePartitionsExec c is still constant. Hence c@2 ASC ordering is already satisfied. @@ -3420,24 +3534,24 @@ fn do_not_preserve_ordering_through_repartition2() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort. assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -3457,10 +3571,10 @@ fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3480,10 +3594,10 @@ fn do_not_put_sort_when_input_is_invalid() -> Result<()> { // Ordering requirement of sort required exec is NOT satisfied // by existing ordering at the source. 
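// (SortRequiredExec is, as far as this test file shows, a test-only operator
// that declares a hard ordering requirement on its input, so the assertions
// here can observe whether the rule satisfies or ignores that requirement.)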
assert_plan!(physical_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let mut config = ConfigOptions::new(); config.execution.target_partitions = 10; @@ -3493,11 +3607,11 @@ SortRequiredExec: [a@0 ASC] // Since at the start of the rule the ordering requirement is not satisfied, // the EnforceDistribution rule doesn't satisfy this requirement either. assert_plan!(dist_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -3516,10 +3630,10 @@ fn put_sort_when_input_is_valid() -> Result<()> { // Ordering requirement of sort required exec is satisfied // by existing ordering at the source. assert_plan!(physical_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let mut config = ConfigOptions::new(); config.execution.target_partitions = 10; @@ -3529,10 +3643,10 @@ SortRequiredExec: [a@0 ASC] // Since at the start of the rule the ordering requirement is satisfied, // the EnforceDistribution rule satisfies this requirement as well.
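// For reference, the `dist_plan` asserted next is produced by invoking the
// rule directly as a `PhysicalOptimizerRule`; a hedged sketch of that
// invocation (the exact elided code between the hunks may differ):
//
//     use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution;
//     use datafusion_physical_optimizer::PhysicalOptimizerRule;
//
//     let mut config = ConfigOptions::new();
//     config.execution.target_partitions = 10;
//     let dist_plan = EnforceDistribution::new().optimize(physical_plan, &config)?;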
assert_plan!(dist_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -3556,10 +3670,10 @@ fn do_not_add_unnecessary_hash() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3586,14 +3700,14 @@ fn do_not_add_unnecessary_hash2() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Since the hash requirements of this operator are satisfied, there shouldn't be
// a hash repartition here let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); @@ -3607,17 +3721,15 @@ fn optimize_away_unnecessary_repartition() -> Result<()> { let physical_plan = coalesce_partitions_exec(repartition_exec(parquet_exec())); assert_plan!(physical_plan, @r" -CoalescePartitionsExec - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, - @r" -DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet"); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3631,23 +3743,23 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { ))); assert_plan!(physical_plan, @r" -FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -FilterExec: c@2 = 0 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3671,29 +3783,29 @@ async fn test_distribute_sort_parquet() -> Result<()> { // prior to optimization, this is the starting plan assert_plan!(physical_plan, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // what the enforce distribution run does.
let plan_distribution = test_config.to_plan(physical_plan.clone(), &[Run::Distribution]); assert_plan!(plan_distribution, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet + "); // what the sort parallelization (in enforce sorting), does after the enforce distribution changes let plan_both = test_config.to_plan(physical_plan, &[Run::Distribution, Run::Sorting]); assert_plan!(plan_both, @r" -SortPreservingMergeExec: [c@2 ASC] - SortExec: expr=[c@2 ASC], preserve_partitioning=[true] - DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [c@2 ASC] + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -3720,10 +3832,10 @@ async fn test_distribute_sort_memtable() -> Result<()> { // this is the final, optimized plan assert_plan!(physical_plan, @r" -SortPreservingMergeExec: [id@0 ASC NULLS LAST] - SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] - DataSourceExec: partitions=3, partition_sizes=[34, 33, 33] -"); + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] + DataSourceExec: partitions=3, partition_sizes=[34, 33, 33] + "); Ok(()) } @@ -3781,7 +3893,6 @@ fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { // Verify the plan was transformed to CoalescePartitionsExec result .plan - .as_any() .downcast_ref::() .expect("Expected CoalescePartitionsExec"); @@ -3794,3 +3905,106 @@ fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { Ok(()) } + +/// When a parent requires SinglePartition and maintains input order, order-preserving +/// variants (e.g. SortPreservingMergeExec) should be kept so that ordering can +/// propagate to ancestors. Replacing them with CoalescePartitionsExec would destroy +/// ordering and force unnecessary sorts later. 
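To make the distinction above concrete: both operators collapse a multi-partition input into a single partition, but only the merge variant keeps rows sorted. A minimal sketch of that choice using the real DataFusion operators (`input` and `ordering` are stand-ins, not code from this PR):

use std::sync::Arc;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::ExecutionPlan;

/// Collapse a multi-partition plan into a single partition, keeping the
/// per-partition sort order only when an ordering is supplied.
fn to_single_partition(
    input: Arc<dyn ExecutionPlan>,
    ordering: Option<LexOrdering>,
) -> Arc<dyn ExecutionPlan> {
    match ordering {
        // Merges already-sorted partitions, so the single output stays sorted.
        Some(ordering) => Arc::new(SortPreservingMergeExec::new(ordering, input)),
        // Cheaper, but concatenates batches in whatever order partitions finish.
        None => Arc::new(CoalescePartitionsExec::new(input)),
    }
}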
+#[test] +fn maintains_order_preserves_spm_for_single_partition() -> Result<()> { + let schema = schema(); + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, + options: SortOptions::default(), + }] + .into(); + + // GlobalLimitExec -> LocalLimitExec -> sorted multi-partition parquet + let plan: Arc = + limit_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); + + // Test EnforceDistribution in isolation: SPM should be preserved because + // GlobalLimitExec maintains input order. + let result = ensure_distribution_helper(plan, 10, false)?; + assert_plan!(result, + @r" + GlobalLimitExec: skip=0, fetch=100 + SortPreservingMergeExec: [c@2 ASC] + LocalLimitExec: fetch=100 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + + Ok(()) +} + +/// Tests the cascading effect through a UnionExec with the full optimizer +/// pipeline and `prefer_existing_sort=true`. Each Union branch has an operator +/// that requires SinglePartition and maintains input order. SortPreservingMergeExec +/// should be preserved in each branch, allowing ordering to flow through to the +/// ancestor SortRequiredExec. +#[test] +fn maintains_order_preserves_spm_through_union_with_prefer_existing_sort() -> Result<()> { + let schema = schema(); + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, + options: SortOptions::default(), + }] + .into(); + + let branch1 = + single_partition_maintains_order_exec(parquet_exec_multiple_sorted(vec![ + sort_key.clone(), + ])); + let branch2 = + single_partition_maintains_order_exec(parquet_exec_multiple_sorted(vec![ + sort_key.clone(), + ])); + let plan = sort_required_exec_with_req(union_exec(vec![branch1, branch2]), sort_key); + + let test_config = TestConfig::default().with_prefer_existing_sort(); + + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortRequiredExec: [c@2 ASC] + UnionExec + SinglePartitionMaintainsOrderExec + SortPreservingMergeExec: [c@2 ASC] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + SinglePartitionMaintainsOrderExec + SortPreservingMergeExec: [c@2 ASC] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + + Ok(()) +} + +/// Verifies that `adjust_input_keys_ordering` returns `Transformed::no` +/// for a simple scan plan with no key requirements, avoiding an +/// unnecessary plan rebuild. +#[test] +fn adjust_input_keys_ordering_no_transform_for_scan() -> Result<()> { + let plan: Arc = parquet_exec(); + let requirements = PlanWithKeyRequirements::new_default(plan); + let result = adjust_input_keys_ordering(requirements)?; + assert!( + !result.transformed, + "expected Transformed::no for a scan plan with empty requirements" + ); + Ok(()) +} + +/// Verifies that `adjust_input_keys_ordering` applied via `transform_down` +/// over a filter -> scan tree returns `Transformed::no` when there are no +/// join/aggregate key requirements. 
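The contract these two tests pin down is the `Transformed` wrapper from `datafusion_common::tree_node`: a rewrite reports whether it changed anything, so `transform_down` can skip rebuilding untouched parents. A minimal, generic sketch of a rewrite honoring that contract (illustrative only, not the rule's actual body):

use datafusion_common::tree_node::Transformed;
use datafusion_common::Result;

/// A rewrite that decides it has nothing to do and says so: wrapping the
/// node in `Transformed::no` keeps `transformed == false`, which lets
/// `transform_down` avoid reconstructing the surrounding plan.
fn no_op_rewrite<T>(node: T) -> Result<Transformed<T>> {
    // `Transformed::yes(node)` would instead mark the subtree as rewritten.
    Ok(Transformed::no(node))
}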
+#[test] +fn adjust_input_keys_ordering_no_transform_for_filter_scan() -> Result<()> { + let plan: Arc = filter_exec(parquet_exec()); + let requirements = PlanWithKeyRequirements::new_default(plan); + let result = requirements.transform_down(adjust_input_keys_ordering)?; + assert!( + !result.transformed, + "expected Transformed::no for a filter->scan tree with no key requirements" + ); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index e3a0eb7e1aa6f..40bcdbbd6efef 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -19,21 +19,21 @@ use std::sync::Arc; use crate::memory_limit::DummyStreamPartition; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, - check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, - create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, - hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, - projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, - sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, RequirementsTestExec, + RequirementsTestExec, aggregate_exec, bounded_window_exec, + bounded_window_exec_with_partition, check_integrity, coalesce_partitions_exec, + create_test_schema, create_test_schema2, create_test_schema3, filter_exec, + global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, + parquet_exec_with_sort, projection_exec, repartition_exec, sort_exec, + sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, }; -use arrow::compute::SortOptions; +use arrow::compute::{SortOptions}; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::tree_node::{TreeNode, TransformedResult}; -use datafusion_common::{Result, TableReference}; +use datafusion_common::{create_array, Result, TableReference}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_expr_common::operator::Operator; @@ -58,24 +58,29 @@ use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion::prelude::*; -use arrow::array::{Int32Array, RecordBatch}; +use arrow::array::{record_batch, ArrayRef, Int32Array, RecordBatch}; use arrow::datatypes::{Field}; use arrow_schema::Schema; use datafusion_execution::TaskContext; use datafusion_catalog::streaming::StreamingTable; use futures::StreamExt; -use insta::{assert_snapshot, Settings}; +use insta::{Settings, assert_snapshot}; /// Create a sorted Csv exec fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; let mut builder = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - 
Arc::new(CsvSource::new(false, 0, 0)), + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)), ) .with_file(PartitionedFile::new("x".to_string(), 100)); if let Some(ordering) = LexOrdering::new(sort_exprs) { @@ -361,8 +366,8 @@ async fn test_union_inputs_different_sorted2() -> Result<()> { #[tokio::test] // Test with `repartition_sorts` enabled to preserve pre-sorted partitions and avoid resorting -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(true).await?, @r" @@ -387,8 +392,8 @@ async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_reparti #[tokio::test] // Test with `repartition_sorts` disabled, causing a full resort of the data -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(false).await?, @r" @@ -659,21 +664,13 @@ async fn test_union_inputs_different_sorted7() -> Result<()> { // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r" - Input Plan: + Input / Optimized Plan: SortPreservingMergeExec: [nullable_col@0 ASC] UnionExec SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet - - Optimized Plan: - SortPreservingMergeExec: [nullable_col@0 ASC] - UnionExec - SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet - SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet "); // Union preserves the inputs ordering, and we should not change any of the SortExecs under UnionExec @@ -773,8 +770,8 @@ async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { } #[tokio::test] -async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( -) -> Result<()> { +async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns() +-> Result<()> { let schema = create_test_schema()?; let source = parquet_exec(schema.clone()); let ordering = [sort_expr_options( @@ -1072,8 +1069,8 @@ async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { } #[tokio::test] -async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement( -) -> Result<()> { +async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement() +-> Result<()> { let schema = create_test_schema()?; let source = parquet_exec(schema.clone()); let ordering = [sort_expr_options( @@ -1259,7 +1256,8 @@ async fn 
test_union_inputs_different_sorted_with_limit() -> Result<()> { let physical_plan = sort_preserving_merge_exec(ordering3, union); let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); - // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` + // Should not change the unnecessarily fine `SortExec`s because there are + // explicit limit nodes above the second sort. assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC] @@ -1342,12 +1340,12 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1359,13 +1357,13 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1432,12 +1430,12 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] @@ -1449,12 +1447,12 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC] - 
SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] @@ -1466,13 +1464,13 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1515,13 +1513,13 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC] - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: SortExec: expr=[col_b@3 ASC, nullable_col@0 ASC], preserve_partitioning=[false] - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1542,12 +1540,12 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC] - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: 
expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] @@ -1626,13 +1624,13 @@ async fn test_with_lost_ordering_unbounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] Optimized Plan: SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] "); @@ -1644,13 +1642,13 @@ async fn test_with_lost_ordering_unbounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] Optimized Plan: SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] "); @@ -1669,7 +1667,7 @@ async fn test_with_lost_ordering_bounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false "); @@ -1681,14 +1679,14 @@ async fn test_with_lost_ordering_bounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false Optimized Plan: SortPreservingMergeExec: [a@0 ASC] SortExec: expr=[a@0 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false "); @@ -1710,7 
+1708,7 @@ async fn test_do_not_pushdown_through_spm() -> Result<()> { Input / Optimized Plan: SortExec: expr=[b@1 ASC], preserve_partitioning=[false] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false "); @@ -1739,13 +1737,13 @@ async fn test_pushdown_through_spm() -> Result<()> { Input Plan: SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false Optimized Plan: SortPreservingMergeExec: [a@0 ASC, b@1 ASC] SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false "); Ok(()) @@ -1769,7 +1767,7 @@ async fn test_window_multi_layer_requirement() -> Result<()> { BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false @@ -1847,9 +1845,7 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { )] .into(); let sort = sort_exec(ordering.clone(), source); - // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let coalesce_batches = coalesce_batches_exec(sort, 128); - let window_agg = bounded_window_exec("non_nullable_col", ordering, coalesce_batches); + let window_agg = bounded_window_exec("non_nullable_col", ordering, sort); let ordering2: LexOrdering = [sort_expr_options( "non_nullable_col", &window_agg.schema(), @@ -1875,17 +1871,15 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { FilterExec: NOT non_nullable_col@1 SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - CoalesceBatchesExec: target_batch_size=128 - SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] - DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] Optimized Plan: WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: 
Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] FilterExec: NOT non_nullable_col@1 BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] - CoalesceBatchesExec: target_batch_size=128 - SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] - DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] "#); Ok(()) @@ -1964,7 +1958,7 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] @@ -2011,7 +2005,7 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { AggregateExec: mode=Final, gby=[], aggr=[] SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortPreservingMergeExec: [non_nullable_col@1 ASC] SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] @@ -2360,7 +2354,7 @@ async fn test_commutativity() -> Result<()> { assert_snapshot!(displayable(orig_plan.as_ref()).indent(true), @r#" SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] DataSourceExec: partitions=1, partition_sizes=[0] "#); @@ -2812,3 +2806,47 @@ async fn test_partial_sort_with_homogeneous_batches() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_sort_with_streaming_table() -> Result<()> { + let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [1, 2, 3]))?; + + let ctx = SessionContext::new(); + + let sort_order = vec![ + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "a", + )), + true, + false, + ), + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "b", + )), + true, + false, + ), + ]; + let schema = batch.schema(); + let batches = Arc::new(DummyStreamPartition { + schema: schema.clone(), + batches: vec![batch], + }) as _; + let provider = StreamingTable::try_new(schema.clone(), vec![batches])? 
+ .with_sort_order(sort_order); + ctx.register_table("test_table", Arc::new(provider))?; + + let sql = "SELECT a FROM test_table GROUP BY a ORDER BY a"; + let results = ctx.sql(sql).await?.collect().await?; + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_columns(), 1); + let expected = create_array!(Int32, vec![1, 2, 3]) as ArrayRef; + assert_eq!(results[0].column(0), &expected); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs index ef233e222912c..de7611ff211a5 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs @@ -31,7 +31,7 @@ use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, WindowAggExec, + BoundedWindowAggExec, WindowAggExec, create_window_expr, }; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode}; use insta::assert_snapshot; diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs new file mode 100644 index 0000000000000..5f64c9e4a5400 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs @@ -0,0 +1,3376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
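The tests in this new file drive the rule through the `OptimizationTest` harness from `pushdown_utils`; at its core that presumably amounts to applying `FilterPushdown` as an ordinary `PhysicalOptimizerRule`. A minimal sketch of the direct invocation (a hedged illustration, not the harness code itself):

use std::sync::Arc;
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_physical_optimizer::filter_pushdown::FilterPushdown;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;

/// Apply the standalone pushdown rule: scans that support pushdown absorb
/// eligible predicates, while unsupported ones keep their FilterExec parent.
fn push_down_filters(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    FilterPushdown::new().optimize(plan, &ConfigOptions::default())
}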
+ +use std::sync::{Arc, LazyLock}; + +use arrow::{ + array::record_batch, + datatypes::{DataType, Field, Schema, SchemaRef}, + util::pretty::pretty_format_batches, +}; +use arrow_schema::SortOptions; +use datafusion::{ + assert_batches_eq, + logical_expr::Operator, + physical_plan::{ + PhysicalExpr, + expressions::{BinaryExpr, Column, Literal}, + }, + prelude::{SessionConfig, SessionContext}, + scalar::ScalarValue, +}; +use datafusion_catalog::memory::DataSourceExec; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::{ + PartitionedFile, file_groups::FileGroup, file_scan_config::FileScanConfigBuilder, +}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_expr::ScalarUDF; +use datafusion_functions::math::random::RandomFunc; +use datafusion_functions_aggregate::{count::count_udaf, min_max::min_udaf}; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, expressions::col}; +use datafusion_physical_expr::{ + Partitioning, ScalarFunctionExpr, aggregate::AggregateExprBuilder, +}; +use datafusion_physical_optimizer::{ + PhysicalOptimizerRule, filter_pushdown::FilterPushdown, +}; +use datafusion_physical_plan::{ + ExecutionPlan, + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, + coalesce_partitions::CoalescePartitionsExec, + collect, + filter::{FilterExec, FilterExecBuilder}, + projection::ProjectionExec, + repartition::RepartitionExec, + sorts::sort::SortExec, +}; + +use super::pushdown_utils::{ + OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test, +}; +use datafusion_physical_plan::union::UnionExec; +use object_store::memory::InMemory; + +#[test] +fn test_pushdown_into_scan() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_pushdown_volatile_functions_not_allowed() { + // Test that we do not push down filters with volatile functions + // Use random() as an example of a volatile function + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let cfg = Arc::new(ConfigOptions::default()); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("a", &schema()).unwrap()), + Operator::Eq, + Arc::new( + ScalarFunctionExpr::try_new( + Arc::new(ScalarUDF::from(RandomFunc::new())), + vec![], + &schema(), + cfg, + ) + .unwrap(), + ), + )) as Arc; + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + // expect the filter to not be pushed down + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = random() + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = random() + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + 
", + ); +} + +/// Show that we can use config options to determine how to do pushdown. +#[test] +fn test_pushdown_into_scan_with_config_options() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; + + let mut cfg = ConfigOptions::default(); + insta::assert_snapshot!( + OptimizationTest::new( + Arc::clone(&plan), + FilterPushdown::new(), + false + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); + + cfg.execution.parquet.pushdown_filters = true; + insta::assert_snapshot!( + OptimizationTest::new( + plan, + FilterPushdown::new(), + true + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +// Inner-join part is covered by push_down_filter_parquet.slt::test_hashjoin_parent_filter_pushdown. +// The Left-join part stays in Rust: SQL's outer-join-elimination rewrites +// `LEFT JOIN ... WHERE ` into an INNER JOIN +// before physical filter pushdown runs, so the preserved-vs-non-preserved +// distinction this test exercises is not reachable via SQL. +#[tokio::test] +async fn test_static_filter_pushdown_through_hash_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8View, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8View, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Create filters that can be pushed down to different sides + // We need to create filters in the context of the join output schema + let join_schema = join.schema(); + 
+    // Filter on build side column: a = 'aa'
+    let left_filter = col_lit_predicate("a", "aa", &join_schema);
+    // Filter on probe side column: e = 'ba'
+    let right_filter = col_lit_predicate("e", "ba", &join_schema);
+    // Filter that references both sides: a = d (should not be pushed down)
+    let cross_filter = Arc::new(BinaryExpr::new(
+        col("a", &join_schema).unwrap(),
+        Operator::Eq,
+        col("d", &join_schema).unwrap(),
+    )) as Arc<dyn PhysicalExpr>;
+
+    let filter =
+        Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap());
+    let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap());
+    let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap())
+        as Arc<dyn ExecutionPlan>;
+
+    // Test that filters are pushed down correctly to each side of the join
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = d@3
+        -   FilterExec: e@4 = ba
+        -     FilterExec: a@0 = aa
+        -       HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)]
+        -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - FilterExec: a@0 = d@3
+          -   HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)]
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba
+    "
+    );
+
+    // Test left join: filter on preserved (build) side is pushed down,
+    // filter on non-preserved (probe) side is NOT pushed down.
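+    // Roughly the SQL this corresponds to (table/column names illustrative):
+    //   SELECT * FROM build LEFT JOIN probe ON build.a = probe.d
+    //   WHERE build.a = 'aa' AND probe.e = 'ba'
+    // except that SQL planning would rewrite it to an INNER join (the
+    // probe-side predicate rejects NULLs), which is why the plan is built
+    // by hand here.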
+    let join = Arc::new(
+        HashJoinExec::try_new(
+            TestScanBuilder::new(Arc::clone(&build_side_schema))
+                .with_support(true)
+                .build(),
+            TestScanBuilder::new(Arc::clone(&probe_side_schema))
+                .with_support(true)
+                .build(),
+            vec![(
+                col("a", &build_side_schema).unwrap(),
+                col("d", &probe_side_schema).unwrap(),
+            )],
+            None,
+            &JoinType::Left,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    let join_schema = join.schema();
+    // Filter on build side column (preserved): should be pushed down
+    let left_filter = col_lit_predicate("a", "aa", &join_schema);
+    // Filter on probe side column (not preserved): should NOT be pushed down
+    let right_filter = col_lit_predicate("e", "ba", &join_schema);
+    let filter =
+        Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap());
+    let plan = Arc::new(FilterExec::try_new(right_filter, filter).unwrap())
+        as Arc<dyn ExecutionPlan>;
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: e@4 = ba
+        -   FilterExec: a@0 = aa
+        -     HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - FilterExec: e@4 = ba
+          -   HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)]
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true
+    "
+    );
+}
+
+#[test]
+fn test_filter_collapse() {
+    // filter should be pushed down into the parquet scan with two filters
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+    let predicate1 = col_lit_predicate("a", "foo", &schema());
+    let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap());
+    let predicate2 = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar
+        -   FilterExec: a@0 = foo
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar
+    "
+    );
+}
+
+#[test]
+fn test_filter_with_projection() {
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+    let projection = vec![1, 0];
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, Arc::clone(&scan))
+            .apply_projection(Some(projection))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+
+    // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo, projection=[b@1, a@0]
+        -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - ProjectionExec: expr=[b@1 as b, a@0 as a]
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+    ",
+    );
+
+    // add a test where the filter is on a column that isn't included in the output
+    let projection = vec![1];
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, scan)
+            .apply_projection(Some(projection))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo, projection=[b@1]
+        -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - ProjectionExec: expr=[b@1 as b]
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+    "
+    );
+}
+
+#[test]
+fn test_filter_collapse_outer_fetch_preserved() {
+    // When the outer filter has fetch and inner does not, the merged filter should preserve fetch
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let predicate1 = col_lit_predicate("a", "foo", &schema());
+    let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap());
+    let predicate2 = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate2, filter1)
+            .with_fetch(Some(10))
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar, fetch=10
+        -   FilterExec: a@0 = foo
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - FilterExec: b@1 = bar AND a@0 = foo, fetch=10
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+#[test]
+fn test_filter_collapse_inner_fetch_preserved() {
+    // When the inner filter has fetch and outer does not, the merged filter should preserve fetch
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let predicate1 = col_lit_predicate("a", "foo", &schema());
+    let filter1 = Arc::new(
+        FilterExecBuilder::new(predicate1, scan)
+            .with_fetch(Some(5))
+            .build()
+            .unwrap(),
+    );
+    let predicate2 = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar
+        -   FilterExec: a@0 = foo, fetch=5
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - FilterExec: b@1 = bar AND a@0 = foo, fetch=5
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+#[test]
+fn test_filter_collapse_both_fetch_uses_minimum() {
+    // When both filters have fetch, the merged filter should use the smaller (tighter) fetch.
+    // Inner fetch=5 is tighter than outer fetch=10, so the result should be fetch=5.
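+    // In other words, the merge rule these three tests pin down is, as a
+    // sketch of the semantics (not the actual implementation):
+    //   let merged = match (outer_fetch, inner_fetch) {
+    //       (Some(a), Some(b)) => Some(a.min(b)),
+    //       (f, None) | (None, f) => f,
+    //   };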
+ let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate1 = col_lit_predicate("a", "foo", &schema()); + let filter1 = Arc::new( + FilterExecBuilder::new(predicate1, scan) + .with_fetch(Some(5)) + .build() + .unwrap(), + ); + let predicate2 = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate2, filter1) + .with_fetch(Some(10)) + .build() + .unwrap(), + ); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar, fetch=10 + - FilterExec: a@0 = foo, fetch=5 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - FilterExec: b@1 = bar AND a@0 = foo, fetch=5 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + +#[test] +fn test_filter_with_fetch_fully_pushed_to_scan() { + // When a FilterExec has a fetch limit and all predicates are pushed down + // to a supportive DataSourceExec, the FilterExec is removed and the fetch + // must be propagated to the DataSourceExec. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, scan) + .with_fetch(Some(10)) + .build() + .unwrap(), + ); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, fetch=10 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], limit=10, file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_filter_with_fetch_and_projection_fully_pushed_to_scan() { + // When a FilterExec has both fetch and projection, and all predicates are + // pushed down, the filter is replaced by a ProjectionExec and the fetch + // must still be propagated to the DataSourceExec. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let projection = vec![1, 0]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, scan) + .with_fetch(Some(5)) + .apply_projection(Some(projection)) + .unwrap() + .build() + .unwrap(), + ); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1, a@0], fetch=5 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], limit=5, file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_filter_with_fetch_partially_pushed_to_scan() { + // When a FilterExec has fetch and only some predicates are pushed down, + // the FilterExec remains with the unpushed predicate and keeps its fetch. 
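+    // Here `a = 'foo'` is deterministic and is pushed into the scan, while
+    // `a = random()` is volatile (it can produce a different value on every
+    // evaluation), so it has to stay in the FilterExec above the scan.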
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+    let pushed_predicate = col_lit_predicate("a", "foo", &schema());
+    let volatile_predicate = {
+        let cfg = Arc::new(ConfigOptions::default());
+        Arc::new(BinaryExpr::new(
+            Arc::new(Column::new_with_schema("a", &schema()).unwrap()),
+            Operator::Eq,
+            Arc::new(
+                ScalarFunctionExpr::try_new(
+                    Arc::new(ScalarUDF::from(RandomFunc::new())),
+                    vec![],
+                    &schema(),
+                    cfg,
+                )
+                .unwrap(),
+            ),
+        )) as Arc<dyn PhysicalExpr>
+    };
+    // Combine: a = 'foo' AND a = random()
+    let combined = Arc::new(BinaryExpr::new(
+        pushed_predicate,
+        Operator::And,
+        volatile_predicate,
+    )) as Arc<dyn PhysicalExpr>;
+    let plan = Arc::new(
+        FilterExecBuilder::new(combined, scan)
+            .with_fetch(Some(7))
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo AND a@0 = random(), fetch=7
+        -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - FilterExec: a@0 = random(), fetch=7
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+    "
+    );
+}
+
+#[test]
+fn test_filter_with_fetch_not_pushed_to_unsupportive_scan() {
+    // When the DataSourceExec does not support pushdown, the FilterExec
+    // remains unchanged with its fetch intact.
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, scan)
+            .with_fetch(Some(3))
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo, fetch=3
+        -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - FilterExec: a@0 = foo, fetch=3
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+#[test]
+fn test_push_down_through_transparent_nodes() {
+    // expect the predicate to be pushed down into the DataSource
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap());
+    let repartition = Arc::new(
+        RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(),
+    );
+    let predicate = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap());
+
+    // expect the predicate to be pushed down into the DataSource
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar
+        -   RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1
+        -     FilterExec: a@0 = foo
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar
+    "
+    );
+}
+
+#[test]
+fn test_pushdown_through_aggregates_on_grouping_columns() {
+    // Test that filters on
grouping columns can be pushed through AggregateExec. + // This test has two filters: + // 1. An inner filter (a@0 = foo) below the aggregate - gets pushed to DataSource + // 2. An outer filter (b@1 = bar) above the aggregate - also gets pushed through because 'b' is a grouping column + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let filter = Arc::new( + FilterExecBuilder::new(col_lit_predicate("a", "foo", &schema()), scan) + .with_batch_size(10) + .build() + .unwrap(), + ); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + let group_by = PhysicalGroupBy::new_single(vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ]); + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + filter, + schema(), + ) + .unwrap(), + ); + + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, aggregate) + .with_batch_size(100) + .build() + .unwrap(), + ); + + // Both filters should be pushed down to the DataSource since both reference grouping columns + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=Sorted + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +/// Test various combinations of handling of child pushdown results +/// in an ExecutionPlan in combination with support/not support in a DataSource. +#[test] +fn test_node_handles_child_pushdown_result() { + // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource + // and a FilterExec is created. 
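+    // Summary of the support/inject combinations exercised below:
+    //   scan support | inject_filter | result
+    //   true         | true          | predicate absorbed by the DataSource
+    //   false        | true          | TestNode wraps its child in a FilterExec
+    //   false        | false         | predicate not pushed; no FilterExec created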
+ let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); + + // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); +} + +// Not portable to sqllogictest: requires manually constructing +// `SortExec(CoalescePartitionsExec(scan))`. A SQL `ORDER BY ... LIMIT` over a +// multi-partition scan plans as `SortPreservingMergeExec(SortExec(scan))` +// instead, so the filter-through-coalesce path this test exercises is not +// reachable via SQL. 
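+// For reference, the two shapes side by side (sketch):
+//   hand-built here:  SortExec(TopK) -> CoalescePartitionsExec -> DataSourceExec
+//   via SQL:          SortPreservingMergeExec -> SortExec -> DataSourceExec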
+#[tokio::test]
+async fn test_topk_filter_passes_through_coalesce_partitions() {
+    // Create multiple batches for different partitions
+    let batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab"]),
+            ("b", Utf8, ["bd", "bc"]),
+            ("c", Float64, [1.0, 2.0])
+        )
+        .unwrap(),
+        record_batch!(
+            ("a", Utf8, ["ac", "ad"]),
+            ("b", Utf8, ["bb", "ba"]),
+            ("c", Float64, [2.0, 1.0])
+        )
+        .unwrap(),
+    ];
+
+    // Create a source that supports all batches
+    let source = Arc::new(TestSource::new(schema(), true, batches));
+
+    let base_config =
+        FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source)
+            .with_file_groups(vec![
+                // Partition 0
+                FileGroup::new(vec![PartitionedFile::new("test1.parquet", 123)]),
+                // Partition 1
+                FileGroup::new(vec![PartitionedFile::new("test2.parquet", 123)]),
+            ])
+            .build();
+
+    let scan = DataSourceExec::from_data_source(base_config);
+
+    // Add CoalescePartitionsExec to merge the two partitions
+    let coalesce =
+        Arc::new(CoalescePartitionsExec::new(scan)) as Arc<dyn ExecutionPlan>;
+
+    // Add SortExec with TopK
+    let plan = Arc::new(
+        SortExec::new(
+            LexOrdering::new(vec![PhysicalSortExpr::new(
+                col("b", &schema()).unwrap(),
+                SortOptions::new(true, false),
+            )])
+            .unwrap(),
+            coalesce,
+        )
+        .with_fetch(Some(1)),
+    ) as Arc<dyn ExecutionPlan>;
+
+    // Test optimization - the filter SHOULD pass through CoalescePartitionsExec
+    // if it properly implements from_children (not all_unsupported)
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false]
+        -   CoalescePartitionsExec
+        -     DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false]
+          -   CoalescePartitionsExec
+          -     DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]
+    "
+    );
+}
+
+// Not portable to sqllogictest: this test pins `PartitionMode::Partitioned`
+// by hand-wiring `RepartitionExec(Hash, 12)` on both join sides. A SQL
+// INNER JOIN over small parquet inputs plans as `CollectLeft`, so the
+// per-partition CASE filter this test exercises is not reachable via SQL.
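+// A SQL equivalent would only reach this shape with build-side input large
+// enough that the planner rejects CollectLeft; pinning Partitioned mode
+// directly keeps the test small and deterministic.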
+#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Rough sketch of the MRE we're trying to recreate: + // COPY (select i as k from generate_series(1, 10000000) as t(i)) + // TO 'test_files/scratch/push_down_filter/t1.parquet' + // STORED AS PARQUET; + // COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) + // TO 'test_files/scratch/push_down_filter/t2.parquet' + // STORED AS PARQUET; + // create external table t1 stored as parquet location 'test_files/scratch/push_down_filter/t1.parquet'; + // create external table t2 stored as parquet location 'test_files/scratch/push_down_filter/t2.parquet'; + // explain + // select * + // from t1 + // join t2 on t1.k = t2.k; + // +---------------+------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+------------------------------------------------------------+ + // | physical_plan | ┌───────────────────────────┐ | + // | | │ HashJoinExec │ | + // | | │ -------------------- ├──────────────┐ | + // | | │ on: (k = k) │ │ | + // | | └─────────────┬─────────────┘ │ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ RepartitionExec ││ RepartitionExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ partition_count(in->out): ││ partition_count(in->out): │ | + // | | │ 12 -> 12 ││ 12 -> 12 │ | + // | | │ ││ │ | + // | | │ partitioning_scheme: ││ partitioning_scheme: │ | + // | | │ Hash([k@0], 12) ││ Hash([k@0], 12) │ | + // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ DataSourceExec ││ DataSourceExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ files: 12 ││ files: 12 │ | + // | | │ format: parquet ││ format: parquet │ | + // | | │ ││ predicate: true │ | + // | | └───────────────────────────┘└───────────────────────────┘ | + // | | | + // +---------------+------------------------------------------------------------+ + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Build side: DataSource -> RepartitionExec (Hash) + let build_hash_exprs = vec![ + col("a", 
&build_side_schema).unwrap(),
+        col("b", &build_side_schema).unwrap(),
+    ];
+    let build_repartition = Arc::new(
+        RepartitionExec::try_new(
+            build_scan,
+            Partitioning::Hash(build_hash_exprs, partition_count),
+        )
+        .unwrap(),
+    );
+
+    // Probe side: DataSource -> RepartitionExec (Hash)
+    let probe_hash_exprs = vec![
+        col("a", &probe_side_schema).unwrap(),
+        col("b", &probe_side_schema).unwrap(),
+    ];
+    let probe_repartition = Arc::new(
+        RepartitionExec::try_new(
+            Arc::clone(&probe_scan),
+            Partitioning::Hash(probe_hash_exprs, partition_count),
+        )
+        .unwrap(),
+    );
+
+    // Create HashJoinExec with partitioned inputs
+    let on = vec![
+        (
+            col("a", &build_side_schema).unwrap(),
+            col("a", &probe_side_schema).unwrap(),
+        ),
+        (
+            col("b", &build_side_schema).unwrap(),
+            col("b", &probe_side_schema).unwrap(),
+        ),
+    ];
+    let hash_join = Arc::new(
+        HashJoinExec::try_new(
+            build_repartition,
+            probe_repartition,
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    // Top-level CoalescePartitionsExec
+    let cp =
+        Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc<dyn ExecutionPlan>;
+    // Add a sort for deterministic output
+    let plan = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new(
+            col("a", &probe_side_schema).unwrap(),
+            SortOptions::new(true, false), // descending, nulls_first
+        )])
+        .unwrap(),
+        cp,
+    )) as Arc<dyn ExecutionPlan>;
+
+    // expect the predicate to be pushed down into the probe side DataSource
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+        -   CoalescePartitionsExec
+        -     HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
+        -       RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
+        -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -       RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
+        -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+          -   CoalescePartitionsExec
+          -     HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
+          -       RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
+          -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+          -       RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
+          -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]
+    "
+    );
+
+    // Actually apply the optimization to the plan and execute to see the filter in action
+    let mut config = ConfigOptions::default();
+    config.execution.parquet.pushdown_filters = true;
+    config.optimizer.enable_dynamic_filter_pushdown = true;
+    let plan = FilterPushdown::new_post_optimization()
+        .optimize(plan, &config)
+        .unwrap();
+    let config = SessionConfig::new().with_batch_size(10);
+    let session_ctx = SessionContext::new_with_config(config);
+    session_ctx.register_object_store(
+        ObjectStoreUrl::parse("test://").unwrap().as_ref(),
+        Arc::new(InMemory::new()),
+    );
+    let state = session_ctx.state();
+    let task_ctx = state.task_ctx();
+    
let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Now check what our filter looks like + #[cfg(not(feature = "force_hash_collisions"))] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 5 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) WHEN 8 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) ELSE false END ] + " + ); + + // When hash collisions force all data into a single partition, we optimize away the CASE expression. + // This avoids calling create_hashes() for every row on the probe side, since hash % 1 == 0 always, + // meaning the WHEN 0 branch would always match. This optimization is also important for primary key + // joins or any scenario where all build-side data naturally lands in one partition. + #[cfg(feature = "force_hash_collisions")] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + " + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. + // Issue: https://github.com/apache/datafusion/issues/17451 + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +// Not portable to sqllogictest: this test specifically pins a +// `RepartitionExec(Hash, 12)` between `HashJoinExec(CollectLeft)` and the +// probe-side scan to verify the dynamic filter link survives that boundary +// (regression for #17451). The same CollectLeft filter content and +// pushdown counters are already covered by the simpler slt port +// (push_down_filter_parquet.slt::test_hashjoin_dynamic_filter_pushdown). 
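+// For reference, the per-partition dynamic filter that Partitioned mode
+// produces (see the snapshot above) behaves roughly like this pseudocode,
+// where `bounds_for_partition_n` stands in for the min/max + IN-set checks:
+//   match hash(a, b) % partition_count {
+//       5 => bounds_for_partition_5(a, b),
+//       8 => bounds_for_partition_8(a, b),
+//       _ => false, // no build rows landed in the other partitions
+//   }
+// CollectLeft builds a single hash table, so the test below expects one
+// flat bounds check instead of a CASE over partitions.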
+#[tokio::test]
+async fn test_hashjoin_dynamic_filter_pushdown_collect_left() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    let build_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab"]),
+            ("b", Utf8, ["ba", "bb"]),
+            ("c", Float64, [1.0, 2.0]) // Extra column not used in join
+        )
+        .unwrap(),
+    ];
+    let build_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Float64, false),
+    ]));
+    let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+        .with_support(true)
+        .with_batches(build_batches)
+        .build();
+
+    // Create probe side with more values
+    let probe_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab", "ac", "ad"]),
+            ("b", Utf8, ["ba", "bb", "bc", "bd"]),
+            ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join
+        )
+        .unwrap(),
+    ];
+    let probe_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("e", DataType::Float64, false),
+    ]));
+    let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
+        .with_support(true)
+        .with_batches(probe_batches)
+        .build();
+
+    // Create RepartitionExec nodes for both sides with hash partitioning on join keys
+    let partition_count = 12;
+
+    // Probe side: DataSource -> RepartitionExec(Hash)
+    let probe_hash_exprs = vec![
+        col("a", &probe_side_schema).unwrap(),
+        col("b", &probe_side_schema).unwrap(),
+    ];
+    let probe_repartition = Arc::new(
+        RepartitionExec::try_new(
+            Arc::clone(&probe_scan),
+            Partitioning::Hash(probe_hash_exprs, partition_count), // create multiple partitions on the probe side
+        )
+        .unwrap(),
+    );
+
+    let on = vec![
+        (
+            col("a", &build_side_schema).unwrap(),
+            col("a", &probe_side_schema).unwrap(),
+        ),
+        (
+            col("b", &build_side_schema).unwrap(),
+            col("b", &probe_side_schema).unwrap(),
+        ),
+    ];
+    let hash_join = Arc::new(
+        HashJoinExec::try_new(
+            build_scan,
+            probe_repartition,
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::CollectLeft,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    // Top-level CoalescePartitionsExec
+    let cp =
+        Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc<dyn ExecutionPlan>;
+    // Add a sort for deterministic output
+    let plan = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new(
+            col("a", &probe_side_schema).unwrap(),
+            SortOptions::new(true, false), // descending, nulls_first
+        )])
+        .unwrap(),
+        cp,
+    )) as Arc<dyn ExecutionPlan>;
+
+    // expect the predicate to be pushed down into the probe side DataSource
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+        -   CoalescePartitionsExec
+        -     HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -       RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
+        -         DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+          -   CoalescePartitionsExec
+          -     HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
+          -       
DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + " + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. 
+    // Issue: https://github.com/apache/datafusion/issues/17451
+    assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2);
+
+    insta::assert_snapshot!(
+        result,
+        @r"
+    +----+----+-----+----+----+-----+
+    | a  | b  | c   | a  | b  | e   |
+    +----+----+-----+----+----+-----+
+    | ab | bb | 2.0 | ab | bb | 2.0 |
+    | aa | ba | 1.0 | aa | ba | 1.0 |
+    +----+----+-----+----+----+-----+
+    ",
+    );
+}
+
+#[test]
+fn test_hashjoin_parent_filter_pushdown_same_column_names() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    let build_side_schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("build_val", DataType::Utf8, false),
+    ]));
+    let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+        .with_support(true)
+        .build();
+
+    let probe_side_schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("probe_val", DataType::Utf8, false),
+    ]));
+    let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
+        .with_support(true)
+        .build();
+
+    let on = vec![(
+        col("id", &build_side_schema).unwrap(),
+        col("id", &probe_side_schema).unwrap(),
+    )];
+    let join = Arc::new(
+        HashJoinExec::try_new(
+            build_scan,
+            probe_scan,
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    let join_schema = join.schema();
+
+    let build_id_filter = col_lit_predicate("id", "aa", &join_schema);
+    let probe_val_filter = col_lit_predicate("probe_val", "x", &join_schema);
+
+    let filter =
+        Arc::new(FilterExec::try_new(build_id_filter, Arc::clone(&join) as _).unwrap());
+    let plan = Arc::new(FilterExec::try_new(probe_val_filter, filter).unwrap())
+        as Arc<dyn ExecutionPlan>;
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: probe_val@3 = x
+        -   FilterExec: id@0 = aa
+        -     HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)]
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true, predicate=id@0 = aa
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true, predicate=probe_val@1 = x
+    "
+    );
+}
+
+#[test]
+fn test_hashjoin_parent_filter_pushdown_mark_join() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    let left_schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("val", DataType::Utf8, false),
+    ]));
+    let left_scan = TestScanBuilder::new(Arc::clone(&left_schema))
+        .with_support(true)
+        .build();
+
+    let right_schema =
+        Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
+    let right_scan = TestScanBuilder::new(Arc::clone(&right_schema))
+        .with_support(true)
+        .build();
+
+    let on = vec![(
+        col("id", &left_schema).unwrap(),
+        col("id", &right_schema).unwrap(),
+    )];
+    let join = Arc::new(
+        HashJoinExec::try_new(
+            left_scan,
+            right_scan,
+            on,
+            None,
+            &JoinType::LeftMark,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    let join_schema = join.schema();
+
+    let left_filter = col_lit_predicate("val", "x", &join_schema);
+    let mark_filter = col_lit_predicate("mark", true, &join_schema);
+
+    let filter =
+        Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap());
+    let plan = Arc::new(FilterExec::try_new(mark_filter, filter).unwrap())
+        as Arc<dyn ExecutionPlan>;
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: mark@2 = true
+        -   FilterExec: val@1 = x
+        -     HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - FilterExec: mark@2 = true
+          -   HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)]
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true, predicate=val@1 = x
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true
+    "
+    );
+}
+
+/// Test that filters on join key columns are pushed to both sides of semi/anti joins.
+/// For LeftSemi/LeftAnti, the output only contains left columns, but filters on
+/// join key columns can also be pushed to the right (non-preserved) side because
+/// the equijoin condition guarantees the key values match.
+#[test]
+fn test_hashjoin_parent_filter_pushdown_semi_anti_join() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    let left_schema = Arc::new(Schema::new(vec![
+        Field::new("k", DataType::Utf8, false),
+        Field::new("v", DataType::Utf8, false),
+    ]));
+    let left_scan = TestScanBuilder::new(Arc::clone(&left_schema))
+        .with_support(true)
+        .build();
+
+    let right_schema = Arc::new(Schema::new(vec![
+        Field::new("k", DataType::Utf8, false),
+        Field::new("w", DataType::Utf8, false),
+    ]));
+    let right_scan = TestScanBuilder::new(Arc::clone(&right_schema))
+        .with_support(true)
+        .build();
+
+    let on = vec![(
+        col("k", &left_schema).unwrap(),
+        col("k", &right_schema).unwrap(),
+    )];
+
+    let join = Arc::new(
+        HashJoinExec::try_new(
+            left_scan,
+            right_scan,
+            on,
+            None,
+            &JoinType::LeftSemi,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    let join_schema = join.schema();
+    // Filter on join key column: k = 'x', should be pushed to BOTH sides
+    let key_filter = col_lit_predicate("k", "x", &join_schema);
+    // Filter on non-key column: v = 'y', should only be pushed to the left side
+    let val_filter = col_lit_predicate("v", "y", &join_schema);
+
+    let filter =
+        Arc::new(FilterExec::try_new(key_filter, Arc::clone(&join) as _).unwrap());
+    let plan = Arc::new(FilterExec::try_new(val_filter, filter).unwrap())
+        as Arc<dyn ExecutionPlan>;
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: v@1 = y
+        -   FilterExec: k@0 = x
+        -     HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)]
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true, predicate=k@0 = x AND v@1 = y
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true, predicate=k@0 = x
+    "
+    );
+}
+
+#[test]
+fn test_filter_pushdown_through_union() {
+    let scan1 = TestScanBuilder::new(schema()).with_support(true).build();
+    let scan2 = TestScanBuilder::new(schema()).with_support(true).build();
+
+    let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();
+
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   UnionExec
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - UnionExec
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+    "
+    );
+}
+
+#[test]
+fn test_filter_pushdown_through_union_mixed_support() {
+    // Test case where one child supports filter pushdown and one doesn't
+    let scan1 = TestScanBuilder::new(schema()).with_support(true).build();
+    let scan2 = TestScanBuilder::new(schema()).with_support(false).build();
+
+    let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();
+
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   UnionExec
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - UnionExec
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+          -   FilterExec: a@0 = foo
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+#[test]
+fn test_filter_pushdown_through_union_does_not_support() {
+    // Test case where neither child supports filter pushdown
+    let scan1 = TestScanBuilder::new(schema()).with_support(false).build();
+    let scan2 = TestScanBuilder::new(schema()).with_support(false).build();
+
+    let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();
+
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   UnionExec
+        -     DataSourceExec:
file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + +#[test] +fn test_filter_with_fetch_fully_pushed_through_union() { + // When a FilterExec with fetch wraps a UnionExec and all predicates are + // pushed down, UnionExec does not support with_fetch, so a LocalLimitExec + // should be inserted to preserve the fetch limit. + let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(true).build(); + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, union) + .with_fetch(Some(10)) + .build() + .unwrap(), + ); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @" + OptimizationTest: + input: + - FilterExec: a@0 = foo, fetch=10 + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - LocalLimitExec: fetch=10 + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_filter_with_fetch_and_projection_fully_pushed_through_union() { + // When a FilterExec with both fetch and projection wraps a UnionExec and + // all predicates are pushed down, we should get a ProjectionExec on top of + // a LocalLimitExec wrapping the UnionExec. 
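+    // Expected shape, per the snapshot below (the projection only reorders
+    // columns, so it can sit above the limit without changing row counts):
+    //   ProjectionExec(b, a) -> LocalLimitExec(fetch=5) -> UnionExec -> scans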
+    let scan1 = TestScanBuilder::new(schema()).with_support(true).build();
+    let scan2 = TestScanBuilder::new(schema()).with_support(true).build();
+    let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();
+    let projection = vec![1, 0];
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, union)
+            .with_fetch(Some(5))
+            .apply_projection(Some(projection))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo, projection=[b@1, a@0], fetch=5
+        -   UnionExec
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - ProjectionExec: expr=[b@1 as b, a@0 as a]
+          -   LocalLimitExec: fetch=5
+          -     UnionExec
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+    "
+    );
+}
+
+#[test]
+fn test_filter_with_fetch_not_fully_pushed_through_union() {
+    // When a FilterExec with fetch wraps a UnionExec whose children don't
+    // support pushdown, the filter is still split into per-child FilterExecs
+    // and the fetch is preserved by a LocalLimitExec above the union.
+    let scan1 = TestScanBuilder::new(schema()).with_support(false).build();
+    let scan2 = TestScanBuilder::new(schema()).with_support(false).build();
+    let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, union)
+            .with_fetch(Some(8))
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo, fetch=8
+        -   UnionExec
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - LocalLimitExec: fetch=8
+          -   UnionExec
+          -     FilterExec: a@0 = foo
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+          -     FilterExec: a@0 = foo
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+/// Schema:
+/// a: String
+/// b: String
+/// c: f64
+static TEST_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
+    let fields = vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Float64, false),
+    ];
+    Arc::new(Schema::new(fields))
+});
+
+fn schema() -> SchemaRef {
+    Arc::clone(&TEST_SCHEMA)
+}
+
+// test_topk_with_projection_transformation_on_dyn_filter has been ported
+// to datafusion/sqllogictest/test_files/push_down_filter_parquet.slt; see
+// `topk_proj` fixture for the 4 representative cases (reorder, prune,
+// expression, alias shadowing). The `run_projection_dyn_filter_case`
+// harness was removed along with it.
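+// Example of the helper defined below (hypothetical usage; the rendered
+// form is what appears in the snapshots above):
+//   let p = col_lit_predicate("a", "foo", &schema());
+//   // displays as `a@0 = foo`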
+
+/// Returns a predicate that is a binary expression col = lit
+fn col_lit_predicate(
+    column_name: &str,
+    scalar_value: impl Into<ScalarValue>,
+    schema: &Schema,
+) -> Arc<dyn PhysicalExpr> {
+    let scalar_value = scalar_value.into();
+    Arc::new(BinaryExpr::new(
+        Arc::new(Column::new_with_schema(column_name, schema).unwrap()),
+        Operator::Eq,
+        Arc::new(Literal::new(scalar_value)),
+    ))
+}
+
+// ==== Aggregate Dynamic Filter tests ====
+//
+// The end-to-end min/max dynamic filter cases (simple/min/max/mixed/all-nulls)
+// have been ported to
+// `datafusion/sqllogictest/test_files/push_down_filter_regression.slt`.
+// The `run_aggregate_dyn_filter_case` harness used to drive them was removed
+// along with the test functions.
+
+/// Non-partial (Single) aggregates should skip dynamic filter initialization.
+#[test]
+fn test_aggregate_dynamic_filter_not_created_for_single_mode() {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+    let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()];
+
+    let scan = TestScanBuilder::new(Arc::clone(&schema))
+        .with_support(true)
+        .with_batches(batches)
+        .build();
+
+    let min_expr =
+        AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()])
+            .schema(Arc::clone(&schema))
+            .alias("min_a")
+            .build()
+            .unwrap();
+
+    let plan: Arc<dyn ExecutionPlan> = Arc::new(
+        AggregateExec::try_new(
+            AggregateMode::Single,
+            PhysicalGroupBy::new_single(vec![]),
+            vec![min_expr.into()],
+            vec![None],
+            scan,
+            Arc::clone(&schema),
+        )
+        .unwrap(),
+    );
+
+    let mut config = ConfigOptions::default();
+    config.execution.parquet.pushdown_filters = true;
+    config.optimizer.enable_dynamic_filter_pushdown = true;
+
+    let optimized = FilterPushdown::new_post_optimization()
+        .optimize(plan, &config)
+        .unwrap();
+
+    let formatted = format_plan_for_test(&optimized);
+    assert!(
+        !formatted.contains("DynamicFilter ["),
+        "dynamic filter should not be created for AggregateMode::Single: {formatted}"
+    );
+}
+
+#[test]
+fn test_pushdown_filter_on_non_first_grouping_column() {
+    // Test that filters on non-first grouping columns are still pushed down
+    // SELECT a, b, count(*) as cnt FROM table GROUP BY a, b HAVING b = 'bar'
+    // The filter is on 'b' (second grouping column), should push down
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+
+    let aggregate_expr = vec![
+        AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()])
+            .schema(schema())
+            .alias("cnt")
+            .build()
+            .map(Arc::new)
+            .unwrap(),
+    ];
+
+    let group_by = PhysicalGroupBy::new_single(vec![
+        (col("a", &schema()).unwrap(), "a".to_string()),
+        (col("b", &schema()).unwrap(), "b".to_string()),
+    ]);
+
+    let aggregate = Arc::new(
+        AggregateExec::try_new(
+            AggregateMode::Final,
+            group_by,
+            aggregate_expr.clone(),
+            vec![None],
+            scan,
+            schema(),
+        )
+        .unwrap(),
+    );
+
+    let predicate = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar
+        -   AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([1])
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c],
file_type=test, pushdown_supported=true, predicate=b@1 = bar + " + ); +} + +#[test] +fn test_no_pushdown_grouping_sets_filter_on_missing_column() { + // Test that filters on columns missing from some grouping sets are NOT pushed through + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Create GROUPING SETS with (a, b) and (b) + let group_by = PhysicalGroupBy::new( + vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "b".to_string(), + ), + ], + vec![ + vec![false, false], // (a, b) - both present + vec![true, false], // (b) - a is NULL, b present + ], + true, + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + // Filter on column 'a' which is missing in the second grouping set, should not be pushed down + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); +} + +#[test] +fn test_pushdown_grouping_sets_filter_on_common_column() { + // Test that filters on columns present in ALL grouping sets ARE pushed through + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Create GROUPING SETS with (a, b) and (b) + let group_by = PhysicalGroupBy::new( + vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "b".to_string(), + ), + ], + vec![ + vec![false, false], // (a, b) - both present + vec![true, false], // (b) - a is NULL, b present + ], + true, + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + // Filter on column 'b' which is present in all grouping sets will be pushed down + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as 
b), (NULL as a, b@1 as b)], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt], ordering_mode=PartiallySorted([1]) + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar + " + ); +} + +#[test] +fn test_pushdown_with_empty_group_by() { + // Test that filters can be pushed down when GROUP BY is empty (no grouping columns) + // SELECT count(*) as cnt FROM table WHERE a = 'foo' + // There are no grouping columns, so the filter should still push down + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Empty GROUP BY - no grouping columns + let group_by = PhysicalGroupBy::new_single(vec![]); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + // Filter on 'a' + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + // The filter should be pushed down even with empty GROUP BY + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - AggregateExec: mode=Final, gby=[], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_pushdown_through_aggregate_with_reordered_input_columns() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + // Reorder scan output from (a, b, c) to (c, a, b) + let reordered_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Float64, false), + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let projection = Arc::new( + ProjectionExec::try_new( + vec![ + (col("c", &schema()).unwrap(), "c".to_string()), + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ], + scan, + ) + .unwrap(), + ); + + let aggregate_expr = vec![ + AggregateExprBuilder::new( + count_udaf(), + vec![col("c", &reordered_schema).unwrap()], + ) + .schema(reordered_schema.clone()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Group by a@1, b@2 (input indices in reordered schema) + let group_by = PhysicalGroupBy::new_single(vec![ + (col("a", &reordered_schema).unwrap(), "a".to_string()), + (col("b", &reordered_schema).unwrap(), "b".to_string()), + ]); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr, + vec![None], + projection, + reordered_schema, + ) + .unwrap(), + ); + + // Filter on b@1 in aggregate's output schema (a@0, b@1, cnt@2) + // The grouping expr for b references input index 2, but output index is 1. 
+ let agg_output_schema = aggregate.schema(); + let predicate = col_lit_predicate("b", "bar", &agg_output_schema); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + // The filter should be pushed down + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - AggregateExec: mode=Final, gby=[a@1 as a, b@2 as b], aggr=[cnt] + - ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[a@1 as a, b@2 as b], aggr=[cnt], ordering_mode=PartiallySorted([1]) + - ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar + " + ); +} + +#[test] +fn test_pushdown_through_aggregate_grouping_sets_with_reordered_input() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let reordered_schema = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Float64, false), + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let projection = Arc::new( + ProjectionExec::try_new( + vec![ + (col("c", &schema()).unwrap(), "c".to_string()), + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ], + scan, + ) + .unwrap(), + ); + + let aggregate_expr = vec![ + AggregateExprBuilder::new( + count_udaf(), + vec![col("c", &reordered_schema).unwrap()], + ) + .schema(reordered_schema.clone()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Use grouping sets (a, b) and (b). 
+    let group_by = PhysicalGroupBy::new(
+        vec![
+            (col("a", &reordered_schema).unwrap(), "a".to_string()),
+            (col("b", &reordered_schema).unwrap(), "b".to_string()),
+        ],
+        vec![
+            (
+                Arc::new(Literal::new(ScalarValue::Utf8(None))),
+                "a".to_string(),
+            ),
+            (
+                Arc::new(Literal::new(ScalarValue::Utf8(None))),
+                "b".to_string(),
+            ),
+        ],
+        vec![vec![false, false], vec![true, false]],
+        true,
+    );
+
+    let aggregate = Arc::new(
+        AggregateExec::try_new(
+            AggregateMode::Final,
+            group_by,
+            aggregate_expr,
+            vec![None],
+            projection,
+            reordered_schema,
+        )
+        .unwrap(),
+    );
+
+    let agg_output_schema = aggregate.schema();
+
+    // Filter on b (present in all grouping sets) should be pushed down
+    let predicate = col_lit_predicate("b", "bar", &agg_output_schema);
+    let plan = Arc::new(FilterExec::try_new(predicate, aggregate.clone()).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar
+        -   AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt]
+        -     ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt], ordering_mode=PartiallySorted([1])
+          -   ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar
+    "
+    );
+
+    // Filter on a (missing from second grouping set) should not be pushed down
+    let predicate = col_lit_predicate("a", "foo", &agg_output_schema);
+    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt]
+        -     ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
+        -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - FilterExec: a@0 = foo
+          -   AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt]
+          -     ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+    "
+    );
+}
+
+// Regression coverage for https://github.com/apache/datafusion/issues/21065:
+// the two tests above ensure that filters are pushed down through an
+// AggregateExec whose input columns are reordered by a ProjectionExec.
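+
+// Illustrative aside, not part of this PR: the key step the reordered-input
+// tests exercise is remapping a filter column from the aggregate's *output*
+// schema (where `b` sits at index 1) back to its *input* schema (where `b`
+// sits at index 2) before the filter can travel further down. A minimal
+// sketch of that remapping for the plain-column case; the helper is
+// hypothetical and is not how FilterPushdown actually implements it.
+#[allow(dead_code)]
+fn remap_output_column_to_input(
+    filter_col: &Column,
+    group_exprs: &[(Arc<dyn PhysicalExpr>, String)],
+) -> Option<Column> {
+    // The i-th grouping expression produces the i-th output column of the
+    // aggregate, so look up the grouping expression at the filter column's
+    // output index and reuse its input index if it is a plain column.
+    let (expr, name) = group_exprs.get(filter_col.index())?;
+    let input_col = expr.as_any().downcast_ref::<Column>()?;
+    Some(Column::new(name, input_col.index()))
+}
+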
+#[test]
+fn test_pushdown_with_computed_grouping_key() {
+    // Test filter pushdown with computed grouping expression
+    // SELECT (c + 1.0) as c_plus_1, count(*) FROM table WHERE c > 5.0 GROUP BY (c + 1.0)
+
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+
+    let predicate = Arc::new(BinaryExpr::new(
+        col("c", &schema()).unwrap(),
+        Operator::Gt,
+        Arc::new(Literal::new(ScalarValue::Float64(Some(5.0)))),
+    )) as Arc<dyn PhysicalExpr>;
+    let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap());
+
+    let aggregate_expr = vec![
+        AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()])
+            .schema(schema())
+            .alias("cnt")
+            .build()
+            .map(Arc::new)
+            .unwrap(),
+    ];
+
+    let c_plus_one = Arc::new(BinaryExpr::new(
+        col("c", &schema()).unwrap(),
+        Operator::Plus,
+        Arc::new(Literal::new(ScalarValue::Float64(Some(1.0)))),
+    )) as Arc<dyn PhysicalExpr>;
+
+    let group_by =
+        PhysicalGroupBy::new_single(vec![(c_plus_one, "c_plus_1".to_string())]);
+
+    let plan = Arc::new(
+        AggregateExec::try_new(
+            AggregateMode::Final,
+            group_by,
+            aggregate_expr.clone(),
+            vec![None],
+            filter,
+            schema(),
+        )
+        .unwrap(),
+    );
+
+    // The filter should be pushed down because 'c' is extracted from the grouping expression (c + 1.0)
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - AggregateExec: mode=Final, gby=[c@2 + 1 as c_plus_1], aggr=[cnt]
+        -   FilterExec: c@2 > 5
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - AggregateExec: mode=Final, gby=[c@2 + 1 as c_plus_1], aggr=[cnt]
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=c@2 > 5
+    "
+    );
+}
+
+// Not portable to sqllogictest: in CollectLeft (the mode SQL picks for small
+// data), an empty build side short-circuits the HashJoin and the probe scan
+// is never executed, so its dynamic filter stays at `[ empty ]` rather than
+// collapsing to `[ false ]`. The Rust test uses PartitionMode::Partitioned
+// on a hand-wired plan, which does trigger the `false` path.
+#[tokio::test]
+async fn test_hashjoin_dynamic_filter_all_partitions_empty() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    // Test scenario where all build-side partitions are empty
+    // This validates the code path that sets the filter to `false` when no rows can match
+
+    // Create empty build side
+    let build_batches = vec![];
+    let build_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+    ]));
+    let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+        .with_support(true)
+        .with_batches(build_batches)
+        .build();
+
+    // Create probe side with some data
+    let probe_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab", "ac"]),
+            ("b", Utf8, ["ba", "bb", "bc"])
+        )
+        .unwrap(),
+    ];
+    let probe_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+    ]));
+    let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
+        .with_support(true)
+        .with_batches(probe_batches)
+        .build();
+
+    // Create RepartitionExec nodes for both sides
+    let partition_count = 4;
+
+    let build_hash_exprs = vec![
+        col("a", &build_side_schema).unwrap(),
+        col("b", &build_side_schema).unwrap(),
+    ];
+    let build_repartition = Arc::new(
+        RepartitionExec::try_new(
+            build_scan,
+            Partitioning::Hash(build_hash_exprs, partition_count),
+        )
+        .unwrap(),
+    );
+
+    let probe_hash_exprs = vec![
+        col("a", &probe_side_schema).unwrap(),
+        col("b", &probe_side_schema).unwrap(),
+    ];
+    let probe_repartition = Arc::new(
+        RepartitionExec::try_new(
+            Arc::clone(&probe_scan),
+            Partitioning::Hash(probe_hash_exprs, partition_count),
+        )
+        .unwrap(),
+    );
+
+    // Create HashJoinExec
+    let on = vec![
+        (
+            col("a", &build_side_schema).unwrap(),
+            col("a", &probe_side_schema).unwrap(),
+        ),
+        (
+            col("b", &build_side_schema).unwrap(),
+            col("b", &probe_side_schema).unwrap(),
+        ),
+    ];
+    let plan = Arc::new(
+        HashJoinExec::try_new(
+            build_repartition,
+            probe_repartition,
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    // Apply the filter pushdown optimizer
+    let mut config = SessionConfig::new();
+    config.options_mut().execution.parquet.pushdown_filters = true;
+    let optimizer = FilterPushdown::new_post_optimization();
+    let plan = optimizer.optimize(plan, config.options()).unwrap();
+
+    insta::assert_snapshot!(
+        format_plan_for_test(&plan),
+        @r"
+    - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
+    -   RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1
+    -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true
+    -   RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1
+    -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]
+    "
+    );
+
+    // Run the plan so the dynamic filter gets updated: with an empty build
+    // side it should collapse to `false`
+    let session_ctx = SessionContext::new_with_config(config);
+    session_ctx.register_object_store(
+        ObjectStoreUrl::parse("test://").unwrap().as_ref(),
+        Arc::new(InMemory::new()),
+    );
+    let state = session_ctx.state();
+    let task_ctx = state.task_ctx();
+    // Execute all partitions (required for partitioned hash join coordination)
+    let _batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx))
+        .await
+        .unwrap();
+
+    // After execution the dynamic filter should have collapsed to `false`:
+    // the empty build side means no probe row can ever match
+    insta::assert_snapshot!(
+        format_plan_for_test(&plan),
+        @r"
+    - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)]
+    -   RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1
+    -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true
+    -   RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1
+    -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ false ]
+    "
+    );
+}
+
+// Not portable to sqllogictest: same reason as
+// test_hashjoin_dynamic_filter_pushdown_partitioned — hand-wires
+// PartitionMode::Partitioned, which SQL never picks for small parquet inputs.
+#[tokio::test]
+async fn test_hashjoin_hash_table_pushdown_partitioned() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    // Create build side with limited values
+    let build_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab"]),
+            ("b", Utf8, ["ba", "bb"]),
+            ("c", Float64, [1.0, 2.0]) // Extra column not used in join
+        )
+        .unwrap(),
+    ];
+    let build_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Float64, false),
+    ]));
+    let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+        .with_support(true)
+        .with_batches(build_batches)
+        .build();
+
+    // Create probe side with more values
+    let probe_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab", "ac", "ad"]),
+            ("b", Utf8, ["ba", "bb", "bc", "bd"]),
+            ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join
+        )
+        .unwrap(),
+    ];
+    let probe_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("e", DataType::Float64, false),
+    ]));
+    let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
+        .with_support(true)
+        .with_batches(probe_batches)
+        .build();
+
+    // Create RepartitionExec nodes for both sides with hash partitioning on join keys
+    let partition_count = 12;
+
+    // Build side: DataSource -> RepartitionExec (Hash)
+    let build_hash_exprs = vec![
+        col("a", &build_side_schema).unwrap(),
+        col("b", &build_side_schema).unwrap(),
+    ];
+    let build_repartition = Arc::new(
+        RepartitionExec::try_new(
+            build_scan,
+            Partitioning::Hash(build_hash_exprs, partition_count),
+        )
+        .unwrap(),
+    );
+
+    // Probe side: DataSource -> RepartitionExec (Hash)
+    let probe_hash_exprs = vec![
+        col("a", &probe_side_schema).unwrap(),
+        col("b", &probe_side_schema).unwrap(),
+    ];
+    let probe_repartition = Arc::new(
+        RepartitionExec::try_new(
+            Arc::clone(&probe_scan),
+            Partitioning::Hash(probe_hash_exprs, partition_count),
+        )
+        .unwrap(),
+    );
+
+    // Create HashJoinExec with partitioned inputs
+    let on = vec![
+        (
+            col("a", &build_side_schema).unwrap(),
+            col("a", &probe_side_schema).unwrap(),
+        ),
+        (
+            col("b", &build_side_schema).unwrap(),
+            col("b", &probe_side_schema).unwrap(),
+        ),
+    ];
+    let hash_join = Arc::new(
+        HashJoinExec::try_new(
+            build_repartition,
+            probe_repartition,
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::Partitioned,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    // Top-level CoalescePartitionsExec
+    let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc<dyn ExecutionPlan>;
+    // Add a sort for deterministic output
+    let plan = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new(
+            col("a", &probe_side_schema).unwrap(),
+            SortOptions::new(true, false), // descending, nulls last
+        )])
+        .unwrap(),
+        cp,
+    )) as Arc<dyn ExecutionPlan>;
+
+    // Apply the optimization with config setting that forces HashTable strategy
+    let session_config = SessionConfig::default()
+        .with_batch_size(10)
+        .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1)
+        .set_bool("datafusion.execution.parquet.pushdown_filters", true)
+        .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true);
+    let plan = FilterPushdown::new_post_optimization()
+        .optimize(plan, session_config.options())
+        .unwrap();
+    let session_ctx = SessionContext::new_with_config(session_config);
+    session_ctx.register_object_store(
+        ObjectStoreUrl::parse("test://").unwrap().as_ref(),
+        Arc::new(InMemory::new()),
+    );
+    let state = session_ctx.state();
+    let task_ctx = state.task_ctx();
+    let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx))
+        .await
+        .unwrap();
+
+    // Verify that hash_lookup is used instead of IN (SET)
+    let plan_str = format_plan_for_test(&plan).to_string();
+    assert!(
+        plan_str.contains("hash_lookup"),
+        "Expected hash_lookup in plan but got: {plan_str}"
+    );
+    assert!(
+        !plan_str.contains("IN (SET)"),
+        "Expected no IN (SET) in plan but got: {plan_str}"
+    );
+
+    let result = format!("{}", pretty_format_batches(&batches).unwrap());
+
+    let probe_scan_metrics = probe_scan.metrics().unwrap();
+
+    // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain.
+    assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2);
+
+    // Results should be identical to the InList version
+    insta::assert_snapshot!(
+        result,
+        @r"
+    +----+----+-----+----+----+-----+
+    | a  | b  | c   | a  | b  | e   |
+    +----+----+-----+----+----+-----+
+    | ab | bb | 2.0 | ab | bb | 2.0 |
+    | aa | ba | 1.0 | aa | ba | 1.0 |
+    +----+----+-----+----+----+-----+
+    ",
+    );
+}
+
+// Ported to push_down_filter_parquet.slt (`hl_build`/`hl_probe` fixture).
+// Rust version retained only because the slt port cannot hand-wire the
+// RepartitionExec-above-probe shape this test uses; the hash_lookup vs
+// IN (SET) invariant is captured in the slt port.
+#[tokio::test]
+async fn test_hashjoin_hash_table_pushdown_collect_left() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    let build_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab"]),
+            ("b", Utf8, ["ba", "bb"]),
+            ("c", Float64, [1.0, 2.0]) // Extra column not used in join
+        )
+        .unwrap(),
+    ];
+    let build_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Float64, false),
+    ]));
+    let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+        .with_support(true)
+        .with_batches(build_batches)
+        .build();
+
+    // Create probe side with more values
+    let probe_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["aa", "ab", "ac", "ad"]),
+            ("b", Utf8, ["ba", "bb", "bc", "bd"]),
+            ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join
+        )
+        .unwrap(),
+    ];
+    let probe_side_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("e", DataType::Float64, false),
+    ]));
+    let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
+        .with_support(true)
+        .with_batches(probe_batches)
+        .build();
+
+    // Create a RepartitionExec for the probe side with hash partitioning on join keys
+    let partition_count = 12;
+
+    // Probe side: DataSource -> RepartitionExec(Hash)
+    let probe_hash_exprs = vec![
+        col("a", &probe_side_schema).unwrap(),
+        col("b", &probe_side_schema).unwrap(),
+    ];
+    let probe_repartition = Arc::new(
+        RepartitionExec::try_new(
+            Arc::clone(&probe_scan),
+            Partitioning::Hash(probe_hash_exprs, partition_count), // create multiple partitions on the probe side
+        )
+        .unwrap(),
+    );
+
+    let on = vec![
+        (
+            col("a", &build_side_schema).unwrap(),
+            col("a", &probe_side_schema).unwrap(),
+        ),
+        (
+            col("b", &build_side_schema).unwrap(),
+            col("b", &probe_side_schema).unwrap(),
+        ),
+    ];
+    let hash_join = Arc::new(
+        HashJoinExec::try_new(
+            build_scan,
+            probe_repartition,
+            on,
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::CollectLeft,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    );
+
+    // Top-level CoalescePartitionsExec
+    let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc<dyn ExecutionPlan>;
+    // Add a sort for deterministic output
+    let plan = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new(
+            col("a", &probe_side_schema).unwrap(),
+            SortOptions::new(true, false), // descending, nulls last
+        )])
+        .unwrap(),
+        cp,
+    )) as Arc<dyn ExecutionPlan>;
+
+    // Apply the optimization with config setting that forces HashTable strategy
+    let session_config = SessionConfig::default()
+        .with_batch_size(10)
+        .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1)
+        .set_bool("datafusion.execution.parquet.pushdown_filters", true)
+        .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true);
+    let plan = FilterPushdown::new_post_optimization()
+        .optimize(plan, session_config.options())
+        .unwrap();
+    let session_ctx = SessionContext::new_with_config(session_config);
+    session_ctx.register_object_store(
+        ObjectStoreUrl::parse("test://").unwrap().as_ref(),
+        Arc::new(InMemory::new()),
+    );
+    let state = session_ctx.state();
+    let task_ctx = state.task_ctx();
+    let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx))
+        .await
+        .unwrap();
+
+    // Verify that hash_lookup is used instead of IN (SET)
+    let plan_str = format_plan_for_test(&plan).to_string();
+    assert!(
+        plan_str.contains("hash_lookup"),
+        "Expected hash_lookup in plan but got: {plan_str}"
+    );
+    assert!(
+        !plan_str.contains("IN (SET)"),
+        "Expected no IN (SET) in plan but got: {plan_str}"
+    );
+
+    let result = format!("{}", pretty_format_batches(&batches).unwrap());
+
+    let probe_scan_metrics = probe_scan.metrics().unwrap();
+
+    // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain.
+    assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2);
+
+    // Results should be identical to the InList version
+    insta::assert_snapshot!(
+        result,
+        @r"
+    +----+----+-----+----+----+-----+
+    | a  | b  | c   | a  | b  | e   |
+    +----+----+-----+----+----+-----+
+    | ab | bb | 2.0 | ab | bb | 2.0 |
+    | aa | ba | 1.0 | aa | ba | 1.0 |
+    +----+----+-----+----+----+-----+
+    ",
+    );
+}
+
+// Not portable to sqllogictest: asserts on `HashJoinExec::dynamic_filter().is_used()`,
+// which is a debug-only API. The observable behavior (probe-side scan
+// receiving the dynamic filter when the data source supports it) is
+// already covered by the simpler CollectLeft port in push_down_filter_parquet.slt;
+// the with_support(false) branch has no SQL analog (parquet always supports
+// pushdown).
+#[tokio::test]
+async fn test_hashjoin_dynamic_filter_pushdown_is_used() {
+    use datafusion_common::JoinType;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    // Test both cases: probe side with and without filter pushdown support
+    for (probe_supports_pushdown, expected_is_used) in [(false, false), (true, true)] {
+        let build_side_schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Utf8, false),
+            Field::new("b", DataType::Utf8, false),
+        ]));
+        let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
+            .with_support(true)
+            .with_batches(vec![
+                record_batch!(("a", Utf8, ["aa", "ab"]), ("b", Utf8, ["ba", "bb"]))
+                    .unwrap(),
+            ])
+            .build();
+
+        let probe_side_schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Utf8, false),
+            Field::new("b", DataType::Utf8, false),
+        ]));
+        let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
+            .with_support(probe_supports_pushdown)
+            .with_batches(vec![
+                record_batch!(
+                    ("a", Utf8, ["aa", "ab", "ac", "ad"]),
+                    ("b", Utf8, ["ba", "bb", "bc", "bd"])
+                )
+                .unwrap(),
+            ])
+            .build();
+
+        let on = vec![
+            (
+                col("a", &build_side_schema).unwrap(),
+                col("a", &probe_side_schema).unwrap(),
+            ),
+            (
+                col("b", &build_side_schema).unwrap(),
+                col("b", &probe_side_schema).unwrap(),
+            ),
+        ];
+        let plan = Arc::new(
+            HashJoinExec::try_new(
+                build_scan,
+                probe_scan,
+                on,
+                None,
+                &JoinType::Inner,
+                None,
+                PartitionMode::CollectLeft,
+                datafusion_common::NullEquality::NullEqualsNothing,
+                false,
+            )
+            .unwrap(),
+        ) as Arc<dyn ExecutionPlan>;
+
+        // Apply filter pushdown optimization
+        let mut config = ConfigOptions::default();
+        config.execution.parquet.pushdown_filters = true;
+        config.optimizer.enable_dynamic_filter_pushdown = true;
+        let plan = FilterPushdown::new_post_optimization()
+            .optimize(plan, &config)
+            .unwrap();
+
+        // Get the HashJoinExec to check the dynamic filter
+        let hash_join = plan
+            .downcast_ref::<HashJoinExec>()
+            .expect("Plan should be HashJoinExec");
+
+        // Verify that a dynamic filter was created
+        let dynamic_filter = hash_join
+            .dynamic_filter()
+            .expect("Dynamic filter should be created");
+
+        // Verify that is_used() returns the expected value based on probe side support.
+        // When probe_supports_pushdown=false: no consumer holds a reference (is_used=false)
+        // When probe_supports_pushdown=true: probe side holds a reference (is_used=true)
+        assert_eq!(
+            dynamic_filter.is_used(),
+            expected_is_used,
+            "is_used() should return {expected_is_used} when probe side support is {probe_supports_pushdown}"
+        );
+    }
+}
+
+/// Regression test for https://github.com/apache/datafusion/issues/20109.
+///
+/// Not portable to sqllogictest: the regression specifically targets the
+/// physical FilterPushdown rule running over *stacked* FilterExecs with
+/// projections on a MemorySourceConfig. In SQL the logical optimizer
+/// collapses the two filters before physical planning, so the stacked
+/// FilterExec shape this test exercises is unreachable.
+#[tokio::test]
+async fn test_filter_with_projection_pushdown() {
+    use arrow::array::{Int64Array, RecordBatch, StringArray};
+    use datafusion_physical_plan::collect;
+    use datafusion_physical_plan::filter::FilterExecBuilder;
+
+    // Create schema: [time, event, size]
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("time", DataType::Int64, false),
+        Field::new("event", DataType::Utf8, false),
+        Field::new("size", DataType::Int64, false),
+    ]));
+
+    // Create sample data
+    let timestamps = vec![100i64, 200, 300, 400, 500];
+    let events = vec!["Ingestion", "Ingestion", "Query", "Ingestion", "Query"];
+    let sizes = vec![10i64, 20, 30, 40, 50];
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int64Array::from(timestamps)),
+            Arc::new(StringArray::from(events)),
+            Arc::new(Int64Array::from(sizes)),
+        ],
+    )
+    .unwrap();
+
+    // Create data source
+    let memory_exec = datafusion_datasource::memory::MemorySourceConfig::try_new_exec(
+        &[vec![batch]],
+        schema.clone(),
+        None,
+    )
+    .unwrap();
+
+    // First FilterExec: time < 350 with projection=[event@1, size@2]
+    let time_col = col("time", &memory_exec.schema()).unwrap();
+    let time_filter = Arc::new(BinaryExpr::new(
+        time_col,
+        Operator::Lt,
+        Arc::new(Literal::new(ScalarValue::Int64(Some(350)))),
+    ));
+    let filter1 = Arc::new(
+        FilterExecBuilder::new(time_filter, memory_exec)
+            .apply_projection(Some(vec![1, 2]))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+
+    // Second FilterExec: event = 'Ingestion' with projection=[size@1]
+    let event_col = col("event", &filter1.schema()).unwrap();
+    let event_filter = Arc::new(BinaryExpr::new(
+        event_col,
+        Operator::Eq,
+        Arc::new(Literal::new(ScalarValue::Utf8(Some(
+            "Ingestion".to_string(),
+        )))),
+    ));
+    let filter2 = Arc::new(
+        FilterExecBuilder::new(event_filter, filter1)
+            .apply_projection(Some(vec![1]))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+
+    // Apply filter pushdown optimization
+    let config = ConfigOptions::default();
+    let optimized_plan = FilterPushdown::new()
+        .optimize(Arc::clone(&filter2) as Arc<dyn ExecutionPlan>, &config)
+        .unwrap();
+
+    // Execute the optimized plan - this should not error
+    let ctx = SessionContext::new();
+    let result = collect(optimized_plan, ctx.task_ctx()).await.unwrap();
+
+    // Verify results: should return rows where time < 350 AND event = 'Ingestion'
+    // That's rows with time=100,200 (both have event='Ingestion'), so sizes 10,20
+    let expected = [
+        "+------+",
+        "| size |",
+        "+------+",
+        "| 10   |",
+        "| 20   |",
+        "+------+",
+    ];
+    assert_batches_eq!(expected, &result);
+}
+
+/// Test that ExecutionPlan::apply_expressions() can discover dynamic filters across the plan tree.
+///
+/// Not portable to sqllogictest: asserts by walking the plan tree with
+/// `apply_expressions` + `downcast_ref::<DynamicFilterPhysicalExpr>` and
+/// counting nodes. Neither API is observable from SQL.
+#[tokio::test]
+async fn test_discover_dynamic_filters_via_expressions_api() {
+    use datafusion_common::JoinType;
+    use datafusion_common::tree_node::TreeNodeRecursion;
+    use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr;
+    use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
+
+    fn count_dynamic_filters(plan: &Arc<dyn ExecutionPlan>) -> usize {
+        let mut count = 0;
+
+        // Check expressions from this node using apply_expressions
+        let _ = plan.apply_expressions(&mut |expr| {
+            if let Some(_df) = expr.downcast_ref::<DynamicFilterPhysicalExpr>() {
+                count += 1;
+            }
+            Ok(TreeNodeRecursion::Continue)
+        });
+
+        // Recursively visit children
+        for child in plan.children() {
+            count += count_dynamic_filters(child);
+        }
+
+        count
+    }
+
+    // Create build side (left)
+    let build_batches =
+        vec![record_batch!(("a", Utf8, ["foo", "bar"]), ("b", Int32, [1, 2])).unwrap()];
+    let build_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Int32, false),
+    ]));
+    let build_scan = TestScanBuilder::new(build_schema.clone())
+        .with_support(true)
+        .with_batches(build_batches)
+        .build();
+
+    // Create probe side (right)
+    let probe_batches = vec![
+        record_batch!(
+            ("a", Utf8, ["foo", "bar", "baz", "qux"]),
+            ("c", Float64, [1.0, 2.0, 3.0, 4.0])
+        )
+        .unwrap(),
+    ];
+    let probe_schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("c", DataType::Float64, false),
+    ]));
+    let probe_scan = TestScanBuilder::new(probe_schema.clone())
+        .with_support(true)
+        .with_batches(probe_batches)
+        .build();
+
+    // Create HashJoinExec
+    let plan = Arc::new(
+        HashJoinExec::try_new(
+            build_scan,
+            probe_scan,
+            vec![(
+                col("a", &build_schema).unwrap(),
+                col("a", &probe_schema).unwrap(),
+            )],
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::CollectLeft,
+            datafusion_common::NullEquality::NullEqualsNothing,
+            false,
+        )
+        .unwrap(),
+    ) as Arc<dyn ExecutionPlan>;
+
+    // Before optimization: no dynamic filters
+    let count_before = count_dynamic_filters(&plan);
+    assert_eq!(
+        count_before, 0,
+        "Before optimization, should have no dynamic filters"
+    );
+
+    // Apply filter pushdown optimization (this creates dynamic filters)
+    let mut config = ConfigOptions::default();
+    config.optimizer.enable_dynamic_filter_pushdown = true;
+    config.execution.parquet.pushdown_filters = true;
+    let optimized_plan = FilterPushdown::new_post_optimization()
+        .optimize(plan, &config)
+        .unwrap();
+
+    // After optimization: should discover dynamic filters
+    // We expect 2 dynamic filters:
+    // 1. In the HashJoinExec (producer)
+    // 2. In the DataSourceExec (consumer, pushed down to the probe side)
+    let count_after = count_dynamic_filters(&optimized_plan);
+    assert_eq!(
+        count_after, 2,
+        "After optimization, should discover exactly 2 dynamic filters (1 in HashJoinExec, 1 in DataSourceExec), found {count_after}"
+    );
+}
+
+// ==== Filter pushdown through SortExec tests ====
+
+/// FilterExec above a plain SortExec (no fetch) should be pushed below it.
+/// The scan supports pushdown, so the filter lands in the DataSourceExec.
+#[test]
+fn test_filter_pushdown_through_sort_into_scan() {
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+    let sort = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new_default(
+            col("a", &schema()).unwrap(),
+        )])
+        .unwrap(),
+        scan,
+    ));
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, sort).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), true),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+          -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+    "
+    );
+}
+
+/// FilterExec above a plain SortExec (no fetch) when the scan does NOT
+/// support pushdown. The filter should still move below the sort, landing
+/// as a new FilterExec between SortExec and DataSourceExec.
+#[test]
+fn test_filter_pushdown_through_sort_no_scan_support() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new_default(
+            col("a", &schema()).unwrap(),
+        )])
+        .unwrap(),
+        scan,
+    ));
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, sort).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+          -   FilterExec: a@0 = foo
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+/// Multiple conjunctive filters above a plain SortExec should all be
+/// pushed below the sort as a single FilterExec.
+#[test]
+fn test_multiple_filters_pushdown_through_sort() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new_default(
+            col("a", &schema()).unwrap(),
+        )])
+        .unwrap(),
+        scan,
+    ));
+    // WHERE a = 'foo' AND b = 'bar'
+    let predicate = Arc::new(BinaryExpr::new(
+        col_lit_predicate("a", "foo", &schema()),
+        Operator::And,
+        col_lit_predicate("b", "bar", &schema()),
+    )) as Arc<dyn PhysicalExpr>;
+    let plan = Arc::new(FilterExec::try_new(predicate, sort).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo AND b@1 = bar
+        -   SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+          -   FilterExec: a@0 = foo AND b@1 = bar
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+/// FilterExec above a SortExec with fetch (TopK) must NOT be pushed below,
+/// because limiting happens after filtering — changing the order would alter
+/// the result set.
+#[test]
+fn test_filter_not_pushed_through_sort_with_fetch() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(
+        SortExec::new(
+            LexOrdering::new(vec![PhysicalSortExpr::new_default(
+                col("a", &schema()).unwrap(),
+            )])
+            .unwrap(),
+            scan,
+        )
+        .with_fetch(Some(10)),
+    );
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, sort).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - FilterExec: a@0 = foo
+          -   SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false]
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+/// Same as above, but with a scan that supports pushdown: the filter must
+/// still stay above the SortExec with fetch (TopK), so the plan is unchanged.
+#[test]
+fn test_filter_pushed_through_sort_with_fetch() {
+    let scan = TestScanBuilder::new(schema()).with_support(true).build();
+    let sort = Arc::new(
+        SortExec::new(
+            LexOrdering::new(vec![PhysicalSortExpr::new_default(
+                col("a", &schema()).unwrap(),
+            )])
+            .unwrap(),
+            scan,
+        )
+        .with_fetch(Some(10)),
+    );
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, sort).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+      output:
+        Ok:
+          - FilterExec: a@0 = foo
+          -   SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false]
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+    "
+    );
+}
+
+/// FilterExec with a projection above SortExec. The filter should be pushed
+/// below the sort, and the projection should be preserved as a
+/// ProjectionExec on top.
+#[test]
+fn test_filter_with_projection_pushdown_through_sort() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new_default(
+            col("a", &schema()).unwrap(),
+        )])
+        .unwrap(),
+        scan,
+    ));
+    // FilterExec: b = 'bar', projection=[a] (only output column a)
+    let predicate = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, sort)
+            .apply_projection(Some(vec![0]))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar, projection=[a@0]
+        -   SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - ProjectionExec: expr=[a@0 as a]
+          -   SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+          -     FilterExec: b@1 = bar
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+/// SortExec with preserve_partitioning=true should keep that setting after
+/// filters are pushed below it.
+#[test]
+fn test_filter_pushdown_through_sort_preserves_partitioning() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(
+        SortExec::new(
+            LexOrdering::new(vec![PhysicalSortExpr::new_default(
+                col("a", &schema()).unwrap(),
+            )])
+            .unwrap(),
+            scan,
+        )
+        .with_preserve_partitioning(true),
+    );
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(FilterExec::try_new(predicate, sort).unwrap());
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo
+        -   SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+          -   FilterExec: a@0 = foo
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+/// FilterExec **with a fetch limit** above a plain SortExec. When the filter
+/// is pushed below the sort the fetch should be propagated to the SortExec
+/// (turning it into a TopK).
+#[test]
+fn test_filter_with_fetch_pushdown_through_sort() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new_default(
+            col("a", &schema()).unwrap(),
+        )])
+        .unwrap(),
+        scan,
+    ));
+    let predicate = col_lit_predicate("a", "foo", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, sort)
+            .with_fetch(Some(10))
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: a@0 = foo, fetch=10
+        -   SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false]
+          -   FilterExec: a@0 = foo
+          -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
+
+#[test]
+fn test_filter_pushdown_through_sort_with_projection() {
+    let scan = TestScanBuilder::new(schema()).with_support(false).build();
+    let sort = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr::new(
+            col("a", &schema()).unwrap(),
+            SortOptions::new(true, false), // descending, nulls_last
+        )])
+        .unwrap(),
+        scan,
+    ));
+    // FilterExec: b = 'bar', projection=[a] (only output column a)
+    let predicate = col_lit_predicate("b", "bar", &schema());
+    let plan = Arc::new(
+        FilterExecBuilder::new(predicate, sort)
+            .apply_projection(Some(vec![0]))
+            .unwrap()
+            .build()
+            .unwrap(),
+    );
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, FilterPushdown::new(), false),
+        @r"
+    OptimizationTest:
+      input:
+        - FilterExec: b@1 = bar, projection=[a@0]
+        -   SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+        -     DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+      output:
+        Ok:
+          - ProjectionExec: expr=[a@0 as a]
+          -   SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+          -     FilterExec: b@1 = bar
+          -       DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false
+    "
+    );
+}
diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
deleted file mode 100644
index de61149508904..0000000000000
--- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
+++ /dev/null
@@ -1,2335 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::sync::{Arc, LazyLock};
-
-use arrow::{
-    array::record_batch,
-    datatypes::{DataType, Field, Schema, SchemaRef},
-    util::pretty::pretty_format_batches,
-};
-use arrow_schema::SortOptions;
-use datafusion::{
-    assert_batches_eq,
-    logical_expr::Operator,
-    physical_plan::{
-        expressions::{BinaryExpr, Column, Literal},
-        PhysicalExpr,
-    },
-    prelude::{ParquetReadOptions, SessionConfig, SessionContext},
-    scalar::ScalarValue,
-};
-use datafusion_catalog::memory::DataSourceExec;
-use datafusion_common::config::ConfigOptions;
-use datafusion_datasource::{
-    file_groups::FileGroup, file_scan_config::FileScanConfigBuilder, PartitionedFile,
-};
-use datafusion_execution::object_store::ObjectStoreUrl;
-use datafusion_expr::ScalarUDF;
-use datafusion_functions::math::random::RandomFunc;
-use datafusion_functions_aggregate::count::count_udaf;
-use datafusion_physical_expr::{
-    aggregate::AggregateExprBuilder, Partitioning, ScalarFunctionExpr,
-};
-use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr};
-use datafusion_physical_optimizer::{
-    filter_pushdown::FilterPushdown, PhysicalOptimizerRule,
-};
-use datafusion_physical_plan::{
-    aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
-    coalesce_batches::CoalesceBatchesExec,
-    coalesce_partitions::CoalescePartitionsExec,
-    collect,
-    filter::FilterExec,
-    repartition::RepartitionExec,
-    sorts::sort::SortExec,
-    ExecutionPlan,
-};
-
-use datafusion_physical_plan::union::UnionExec;
-use futures::StreamExt;
-use object_store::{memory::InMemory, ObjectStore};
-use util::{format_plan_for_test, OptimizationTest, TestNode, TestScanBuilder};
-
-use crate::physical_optimizer::filter_pushdown::util::TestSource;
-
-mod util;
-
-#[test]
-fn test_pushdown_into_scan() {
-    let scan = TestScanBuilder::new(schema()).with_support(true).build();
-    let predicate = col_lit_predicate("a", "foo", &schema());
-    let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap());
-
-    // expect the predicate to be pushed down into the DataSource
-    insta::assert_snapshot!(
-        OptimizationTest::new(plan, FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: a@0 = foo
-        -   DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -#[test] -fn test_pushdown_volatile_functions_not_allowed() { - // Test that we do not push down filters with volatile functions - // Use random() as an example of a volatile function - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let cfg = Arc::new(ConfigOptions::default()); - let predicate = Arc::new(BinaryExpr::new( - Arc::new(Column::new_with_schema("a", &schema()).unwrap()), - Operator::Eq, - Arc::new( - ScalarFunctionExpr::try_new( - Arc::new(ScalarUDF::from(RandomFunc::new())), - vec![], - &schema(), - cfg, - ) - .unwrap(), - ), - )) as Arc; - let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); - // expect the filter to not be pushed down - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = random() - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: a@0 = random() - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - ", - ); -} - -/// Show that we can use config options to determine how to do pushdown. -#[test] -fn test_pushdown_into_scan_with_config_options() { - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; - - let mut cfg = ConfigOptions::default(); - insta::assert_snapshot!( - OptimizationTest::new( - Arc::clone(&plan), - FilterPushdown::new(), - false - ), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - " - ); - - cfg.execution.parquet.pushdown_filters = true; - insta::assert_snapshot!( - OptimizationTest::new( - plan, - FilterPushdown::new(), - true - ), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -#[tokio::test] -async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8View, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8View, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); - - // Create probe side with more values - let probe_batches = vec![record_batch!( - ("d", Utf8, ["aa", "ab", "ac", "ad"]), - ("e", Utf8View, ["ba", "bb", "bc", "bd"]), - ("f", Float64, [1.0, 2.0, 3.0, 4.0]) - ) - 
.unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("d", DataType::Utf8, false), - Field::new("e", DataType::Utf8View, false), - Field::new("f", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); - - // Create HashJoinExec - let on = vec![( - col("a", &build_side_schema).unwrap(), - col("d", &probe_side_schema).unwrap(), - )]; - let join = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - let join_schema = join.schema(); - - // Finally let's add a SortExec on the outside to test pushdown of dynamic filters - let sort_expr = - PhysicalSortExpr::new(col("e", &join_schema).unwrap(), SortOptions::default()); - let plan = Arc::new( - SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), join) - .with_fetch(Some(2)), - ) as Arc; - - let mut config = ConfigOptions::default(); - config.optimizer.enable_dynamic_filter_pushdown = true; - config.execution.parquet.pushdown_filters = true; - - // Apply the FilterPushdown optimizer rule - let plan = FilterPushdown::new_post_optimization() - .optimize(Arc::clone(&plan), &config) - .unwrap(); - - // Test that filters are pushed down correctly to each side of the join - insta::assert_snapshot!( - format_plan_for_test(&plan), - @r" - - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false] - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] AND DynamicFilter [ empty ] - " - ); - - // Put some data through the plan to check that the filter is updated to reflect the TopK state - let session_ctx = SessionContext::new_with_config(SessionConfig::new()); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = session_ctx.state(); - let task_ctx = state.task_ctx(); - let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); - // Iterate one batch - stream.next().await.unwrap().unwrap(); - - // Test that filters are pushed down correctly to each side of the join - insta::assert_snapshot!( - format_plan_for_test(&plan), - @r" - - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false], filter=[e@4 IS NULL OR e@4 < bb] - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] - " - ); -} - -// Test both static and dynamic filter pushdown in HashJoinExec. -// Note that static filter pushdown is rare: it should have already happened in the logical optimizer phase. -// However users may manually construct plans that could result in a FilterExec -> HashJoinExec -> Scan setup. -// Dynamic filters arise in cases such as nested inner joins or TopK -> HashJoinExec -> Scan setups. 
-#[tokio::test] -async fn test_static_filter_pushdown_through_hash_join() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8View, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8View, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); - - // Create probe side with more values - let probe_batches = vec![record_batch!( - ("d", Utf8, ["aa", "ab", "ac", "ad"]), - ("e", Utf8View, ["ba", "bb", "bc", "bd"]), - ("f", Float64, [1.0, 2.0, 3.0, 4.0]) - ) - .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("d", DataType::Utf8, false), - Field::new("e", DataType::Utf8View, false), - Field::new("f", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); - - // Create HashJoinExec - let on = vec![( - col("a", &build_side_schema).unwrap(), - col("d", &probe_side_schema).unwrap(), - )]; - let join = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - // Create filters that can be pushed down to different sides - // We need to create filters in the context of the join output schema - let join_schema = join.schema(); - - // Filter on build side column: a = 'aa' - let left_filter = col_lit_predicate("a", "aa", &join_schema); - // Filter on probe side column: e = 'ba' - let right_filter = col_lit_predicate("e", "ba", &join_schema); - // Filter that references both sides: a = d (should not be pushed down) - let cross_filter = Arc::new(BinaryExpr::new( - col("a", &join_schema).unwrap(), - Operator::Eq, - col("d", &join_schema).unwrap(), - )) as Arc<dyn PhysicalExpr>; - - let filter = - Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); - let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); - let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) - as Arc<dyn ExecutionPlan>; - - // Test that filters are pushed down correctly to each side of the join - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = d@3 - - FilterExec: e@4 = ba - - FilterExec: a@0 = aa - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: a@0 = d@3 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba - " - ); - - // Test left join - filters should NOT be pushed down - 
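// (Note the contrast with the Inner-join case above: as the snapshots below
// show, the rule leaves the FilterExec above the Left join instead of pushing
// it into either input. Pushing a filter past an outer join can change which
// rows get null-extended, so the rule stays conservative for non-Inner joins.)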
let join = Arc::new( - HashJoinExec::try_new( - TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .build(), - TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .build(), - vec![( - col("a", &build_side_schema).unwrap(), - col("d", &probe_side_schema).unwrap(), - )], - None, - &JoinType::Left, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - let join_schema = join.schema(); - let filter = col_lit_predicate("a", "aa", &join_schema); - let plan = - Arc::new(FilterExec::try_new(filter, join).unwrap()) as Arc<dyn ExecutionPlan>; - - // Test that filters are NOT pushed down for left join - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = aa - - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: a@0 = aa - - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true - " - ); -} - -#[test] -fn test_filter_collapse() { - // filter should be pushed down into the parquet scan with two filters - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate1 = col_lit_predicate("a", "foo", &schema()); - let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap()); - let predicate2 = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap()); - - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: b@1 = bar - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar - " - ); -} - -#[test] -fn test_filter_with_projection() { - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let projection = vec![1, 0]; - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new( - FilterExec::try_new(predicate, Arc::clone(&scan)) - .unwrap() - .with_projection(Some(projection)) - .unwrap(), - ); - - // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo, projection=[b@1, a@0] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - ProjectionExec: expr=[b@1 as b, a@0 as a] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - ", - ); - - // add a test where the filter is on a column that isn't included in the output - let projection = vec![1]; 
- let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new( - FilterExec::try_new(predicate, scan) - .unwrap() - .with_projection(Some(projection)) - .unwrap(), - ); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo, projection=[b@1] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - ProjectionExec: expr=[b@1 as b] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -#[test] -fn test_push_down_through_transparent_nodes() { - // expect the predicate to be pushed down into the DataSource - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 1)); - let predicate = col_lit_predicate("a", "foo", &schema()); - let filter = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); - let repartition = Arc::new( - RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), - ); - let predicate = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap()); - - // expect the predicate to be pushed down into the DataSource - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: b@1 = bar - - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 - - FilterExec: a@0 = foo - - CoalesceBatchesExec: target_batch_size=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 - - CoalesceBatchesExec: target_batch_size=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar - " - ); -} - -#[test] -fn test_pushdown_through_aggregates_on_grouping_columns() { - // Test that filters on grouping columns can be pushed through AggregateExec. - // This test has two filters: - // 1. An inner filter (a@0 = foo) below the aggregate - gets pushed to DataSource - // 2. 
An outer filter (b@1 = bar) above the aggregate - also gets pushed through because 'b' is a grouping column - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - - let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 10)); - - let filter = Arc::new( - FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(), - ); - - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; - let group_by = PhysicalGroupBy::new_single(vec![ - (col("a", &schema()).unwrap(), "a".to_string()), - (col("b", &schema()).unwrap(), "b".to_string()), - ]); - let aggregate = Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - group_by, - aggregate_expr.clone(), - vec![None], - filter, - schema(), - ) - .unwrap(), - ); - - let coalesce = Arc::new(CoalesceBatchesExec::new(aggregate, 100)); - - let predicate = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); - - // Both filters should be pushed down to the DataSource since both reference grouping columns - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: b@1 = bar - - CoalesceBatchesExec: target_batch_size=100 - - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) - - FilterExec: a@0 = foo - - CoalesceBatchesExec: target_batch_size=10 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - CoalesceBatchesExec: target_batch_size=100 - - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=Sorted - - CoalesceBatchesExec: target_batch_size=10 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar - " - ); -} - -/// Test various combinations of handling of child pushdown results -/// in an ExecutionPlan in combination with support/not support in a DataSource. -#[test] -fn test_node_handles_child_pushdown_result() { - // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource - // and no FilterExec is created. - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - TestInsertExec { inject_filter: true } - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - TestInsertExec { inject_filter: true } - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - ", - ); - - // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource - // and a FilterExec is created. 
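// For quick reference, the full matrix this test walks through:
//
//   with_support | inject_filter | outcome
//   -------------+---------------+--------------------------------------------
//   true         | true          | predicate absorbed by the DataSourceExec
//   false        | true          | TestInsertExec adds a FilterExec instead
//   false        | false         | no pushdown and no FilterExec is created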
- let scan = TestScanBuilder::new(schema()).with_support(false).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - TestInsertExec { inject_filter: true } - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - output: - Ok: - - TestInsertExec { inject_filter: false } - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - ", - ); - - // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource - // and no FilterExec is created. - let scan = TestScanBuilder::new(schema()).with_support(false).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - TestInsertExec { inject_filter: false } - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - output: - Ok: - - TestInsertExec { inject_filter: false } - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - ", - ); -} - -#[tokio::test] -async fn test_topk_dynamic_filter_pushdown() { - let batches = vec![ - record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["bd", "bc"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap(), - record_batch!( - ("a", Utf8, ["ac", "ad"]), - ("b", Utf8, ["bb", "ba"]), - ("c", Float64, [2.0, 1.0]) - ) - .unwrap(), - ]; - let scan = TestScanBuilder::new(schema()) - .with_support(true) - .with_batches(batches) - .build(); - let plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new( - col("b", &schema()).unwrap(), - SortOptions::new(true, false), // descending, nulls_first - )]) - .unwrap(), - Arc::clone(&scan), - ) - .with_fetch(Some(1)), - ) as Arc<dyn ExecutionPlan>; - - // expect the predicate to be pushed down into the DataSource - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - " - ); - - // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK state - let mut config = ConfigOptions::default(); - config.execution.parquet.pushdown_filters = true; - let plan = FilterPushdown::new_post_optimization() - .optimize(plan, &config) - .unwrap(); - let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::new_with_config(config); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = 
session_ctx.state(); - let task_ctx = state.task_ctx(); - let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); - // Iterate one batch - stream.next().await.unwrap().unwrap(); - // Now check what our filter looks like - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false], filter=[b@1 > bd] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@1 > bd ] - " - ); -} - -#[tokio::test] -async fn test_topk_dynamic_filter_pushdown_multi_column_sort() { - let batches = vec![ - // We are going to do ORDER BY b ASC NULLS LAST, a DESC - // And we put the values in such a way that the first batch will fill the TopK - // and we skip the second batch. - record_batch!( - ("a", Utf8, ["ac", "ad"]), - ("b", Utf8, ["bb", "ba"]), - ("c", Float64, [2.0, 1.0]) - ) - .unwrap(), - record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["bc", "bd"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap(), - ]; - let scan = TestScanBuilder::new(schema()) - .with_support(true) - .with_batches(batches) - .build(); - let plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr::new( - col("b", &schema()).unwrap(), - SortOptions::default().asc().nulls_last(), - ), - PhysicalSortExpr::new( - col("a", &schema()).unwrap(), - SortOptions::default().desc().nulls_first(), - ), - ]) - .unwrap(), - Arc::clone(&scan), - ) - .with_fetch(Some(2)), - ) as Arc<dyn ExecutionPlan>; - - // expect the predicate to be pushed down into the DataSource - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - " - ); - - // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK state - let mut config = ConfigOptions::default(); - config.execution.parquet.pushdown_filters = true; - let plan = FilterPushdown::new_post_optimization() - .optimize(plan, &config) - .unwrap(); - let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::new_with_config(config); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = session_ctx.state(); - let task_ctx = state.task_ctx(); - let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); - // Iterate one batch - let res = stream.next().await.unwrap().unwrap(); - #[rustfmt::skip] - let expected = [ - "+----+----+-----+", - "| a | b | c |", - "+----+----+-----+", - "| ad | ba | 1.0 |", - "| ac | bb | 2.0 |", - "+----+----+-----+", - ]; - assert_batches_eq!(expected, &[res]); - // Now check what our filter looks like - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false], filter=[b@1 < bb OR b@1 = bb AND (a@0 
IS NULL OR a@0 > ac)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac) ] - " - ); - // There should be no more batches - assert!(stream.next().await.is_none()); -} - -#[tokio::test] -async fn test_topk_filter_passes_through_coalesce_partitions() { - // Create multiple batches for different partitions - let batches = vec![ - record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["bd", "bc"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap(), - record_batch!( - ("a", Utf8, ["ac", "ad"]), - ("b", Utf8, ["bb", "ba"]), - ("c", Float64, [2.0, 1.0]) - ) - .unwrap(), - ]; - - // Create a source that supports all batches - let source = Arc::new(TestSource::new(true, batches)); - - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test://").unwrap(), - Arc::clone(&schema()), - source, - ) - .with_file_groups(vec![ - // Partition 0 - FileGroup::new(vec![PartitionedFile::new("test1.parquet", 123)]), - // Partition 1 - FileGroup::new(vec![PartitionedFile::new("test2.parquet", 123)]), - ]) - .build(); - - let scan = DataSourceExec::from_data_source(base_config); - - // Add CoalescePartitionsExec to merge the two partitions - let coalesce = Arc::new(CoalescePartitionsExec::new(scan)) as Arc<dyn ExecutionPlan>; - - // Add SortExec with TopK - let plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new( - col("b", &schema()).unwrap(), - SortOptions::new(true, false), - )]) - .unwrap(), - coalesce, - ) - .with_fetch(Some(1)), - ) as Arc<dyn ExecutionPlan>; - - // Test optimization - the filter SHOULD pass through CoalescePartitionsExec - // if it properly implements from_children (not all_unsupported) - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - " - ); -} - -#[tokio::test] -async fn test_topk_filter_passes_through_coalesce_batches() { - let batches = vec![ - record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["bd", "bc"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap(), - record_batch!( - ("a", Utf8, ["ac", "ad"]), - ("b", Utf8, ["bb", "ba"]), - ("c", Float64, [2.0, 1.0]) - ) - .unwrap(), - ]; - - let scan = TestScanBuilder::new(schema()) - .with_support(true) - .with_batches(batches) - .build(); - - let coalesce_batches = - Arc::new(CoalesceBatchesExec::new(scan, 1024)) as Arc<dyn ExecutionPlan>; - - // Add SortExec with TopK - let plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new( - col("b", &schema()).unwrap(), - SortOptions::new(true, false), - )]) - .unwrap(), - coalesce_batches, - ) - .with_fetch(Some(1)), - ) as Arc<dyn ExecutionPlan>; - - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - 
CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] - - CoalesceBatchesExec: target_batch_size=1024 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - " - ); -} - -#[tokio::test] -async fn test_hashjoin_dynamic_filter_pushdown() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) // Extra column not used in join - ) - .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); - - // Create probe side with more values - let probe_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab", "ac", "ad"]), - ("b", Utf8, ["ba", "bb", "bc", "bd"]), - ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join - ) - .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("e", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); - - // Create HashJoinExec with dynamic filter - let on = vec![ - ( - col("a", &build_side_schema).unwrap(), - col("a", &probe_side_schema).unwrap(), - ), - ( - col("b", &build_side_schema).unwrap(), - col("b", &probe_side_schema).unwrap(), - ), - ]; - let plan = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::CollectLeft, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ) as Arc<dyn ExecutionPlan>; - - // expect the predicate to be pushed down into the probe side DataSource - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true - output: - Ok: - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - ", - ); - - // Actually apply the optimization to the plan and execute to see the filter in action - let mut config = ConfigOptions::default(); - config.execution.parquet.pushdown_filters = true; - config.optimizer.enable_dynamic_filter_pushdown = true; - let plan = FilterPushdown::new_post_optimization() - .optimize(plan, &config) - 
.unwrap(); - - // Test for https://github.com/apache/datafusion/pull/17371: dynamic filter linking survives `with_new_children` - let children = plan.children().into_iter().map(Arc::clone).collect(); - let plan = plan.with_new_children(children).unwrap(); - - let config = SessionConfig::new().with_batch_size(10); - let session_ctx = SessionContext::new_with_config(config); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = session_ctx.state(); - let task_ctx = state.task_ctx(); - let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); - // Iterate one batch - stream.next().await.unwrap().unwrap(); - - // Now check what our filter looks like - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] - " - ); -} - -#[tokio::test] -async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - // Rough sketch of the MRE we're trying to recreate: - // COPY (select i as k from generate_series(1, 10000000) as t(i)) - // TO 'test_files/scratch/push_down_filter/t1.parquet' - // STORED AS PARQUET; - // COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) - // TO 'test_files/scratch/push_down_filter/t2.parquet' - // STORED AS PARQUET; - // create external table t1 stored as parquet location 'test_files/scratch/push_down_filter/t1.parquet'; - // create external table t2 stored as parquet location 'test_files/scratch/push_down_filter/t2.parquet'; - // explain - // select * - // from t1 - // join t2 on t1.k = t2.k; - // +---------------+------------------------------------------------------------+ - // | plan_type | plan | - // +---------------+------------------------------------------------------------+ - // | physical_plan | ┌───────────────────────────┐ | - // | | │ CoalesceBatchesExec │ | - // | | │ -------------------- │ | - // | | │ target_batch_size: │ | - // | | │ 8192 │ | - // | | └─────────────┬─────────────┘ | - // | | ┌─────────────┴─────────────┐ | - // | | │ HashJoinExec │ | - // | | │ -------------------- ├──────────────┐ | - // | | │ on: (k = k) │ │ | - // | | └─────────────┬─────────────┘ │ | - // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | - // | | │ CoalesceBatchesExec ││ CoalesceBatchesExec │ | - // | | │ -------------------- ││ -------------------- │ | - // | | │ target_batch_size: ││ target_batch_size: │ | - // | | │ 8192 ││ 8192 │ | - // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | - // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | - // | | │ RepartitionExec ││ RepartitionExec │ | - // | | │ -------------------- ││ -------------------- │ | - // | | │ partition_count(in->out): ││ partition_count(in->out): │ | - // | | │ 12 -> 12 ││ 12 -> 12 │ | - // | | │ ││ │ | - // | | │ partitioning_scheme: ││ partitioning_scheme: │ | - // | | │ Hash([k@0], 12) ││ Hash([k@0], 12) │ | - // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | - // | | 
┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | - // | | │ DataSourceExec ││ DataSourceExec │ | - // | | │ -------------------- ││ -------------------- │ | - // | | │ files: 12 ││ files: 12 │ | - // | | │ format: parquet ││ format: parquet │ | - // | | │ ││ predicate: true │ | - // | | └───────────────────────────┘└───────────────────────────┘ | - // | | | - // +---------------+------------------------------------------------------------+ - - // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) // Extra column not used in join - ) - .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); - - // Create probe side with more values - let probe_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab", "ac", "ad"]), - ("b", Utf8, ["ba", "bb", "bc", "bd"]), - ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join - ) - .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("e", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); - - // Create RepartitionExec nodes for both sides with hash partitioning on join keys - let partition_count = 12; - - // Build side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec - let build_hash_exprs = vec![ - col("a", &build_side_schema).unwrap(), - col("b", &build_side_schema).unwrap(), - ]; - let build_repartition = Arc::new( - RepartitionExec::try_new( - build_scan, - Partitioning::Hash(build_hash_exprs, partition_count), - ) - .unwrap(), - ); - let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); - - // Probe side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec - let probe_hash_exprs = vec![ - col("a", &probe_side_schema).unwrap(), - col("b", &probe_side_schema).unwrap(), - ]; - let probe_repartition = Arc::new( - RepartitionExec::try_new( - Arc::clone(&probe_scan), - Partitioning::Hash(probe_hash_exprs, partition_count), - ) - .unwrap(), - ); - let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); - - // Create HashJoinExec with partitioned inputs - let on = vec![ - ( - col("a", &build_side_schema).unwrap(), - col("a", &probe_side_schema).unwrap(), - ), - ( - col("b", &build_side_schema).unwrap(), - col("b", &probe_side_schema).unwrap(), - ), - ]; - let hash_join = Arc::new( - HashJoinExec::try_new( - build_coalesce, - probe_coalesce, - on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - // Top-level CoalesceBatchesExec - let cb = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc<dyn ExecutionPlan>; - // Top-level CoalescePartitionsExec - let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc<dyn ExecutionPlan>; - // Add a sort for deterministic output - let plan = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new( - col("a", &probe_side_schema).unwrap(), - SortOptions::new(true, false), // descending, nulls_first - )]) - .unwrap(), - cp, - )) as 
Arc<dyn ExecutionPlan>; - - // expect the predicate to be pushed down into the probe side DataSource - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - " - ); - - // Actually apply the optimization to the plan and execute to see the filter in action - let mut config = ConfigOptions::default(); - config.execution.parquet.pushdown_filters = true; - config.optimizer.enable_dynamic_filter_pushdown = true; - let plan = FilterPushdown::new_post_optimization() - .optimize(plan, &config) - .unwrap(); - let config = SessionConfig::new().with_batch_size(10); - let session_ctx = SessionContext::new_with_config(config); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = session_ctx.state(); - let task_ctx = state.task_ctx(); - let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) - .await - .unwrap(); - - // Now check what our filter looks like - #[cfg(not(feature = "force_hash_collisions"))] - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb OR a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba ] - " - ); - - #[cfg(feature = 
"force_hash_collisions")] - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] - " - ); - - let result = format!("{}", pretty_format_batches(&batches).unwrap()); - - let probe_scan_metrics = probe_scan.metrics().unwrap(); - - // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. - // The number of output rows from the probe side scan should stay consistent across executions. - // Issue: https://github.com/apache/datafusion/issues/17451 - assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); - - insta::assert_snapshot!( - result, - @r" - +----+----+-----+----+----+-----+ - | a | b | c | a | b | e | - +----+----+-----+----+----+-----+ - | ab | bb | 2.0 | ab | bb | 2.0 | - | aa | ba | 1.0 | aa | ba | 1.0 | - +----+----+-----+----+----+-----+ - ", - ); -} - -#[tokio::test] -async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) // Extra column not used in join - ) - .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); - - // Create probe side with more values - let probe_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab", "ac", "ad"]), - ("b", Utf8, ["ba", "bb", "bc", "bd"]), - ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join - ) - .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("e", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); - - // Create RepartitionExec nodes for both sides with hash partitioning on join keys - let partition_count = 12; - - // Probe side: DataSource -> RepartitionExec(Hash) -> CoalesceBatchesExec - let probe_hash_exprs = vec![ - col("a", &probe_side_schema).unwrap(), - col("b", &probe_side_schema).unwrap(), - ]; - let probe_repartition = Arc::new( - RepartitionExec::try_new( - Arc::clone(&probe_scan), - Partitioning::Hash(probe_hash_exprs, partition_count), // create multi partitions on probSide - ) - .unwrap(), - ); - let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); - - let on = vec![ - ( - col("a", 
&build_side_schema).unwrap(), - col("a", &probe_side_schema).unwrap(), - ), - ( - col("b", &build_side_schema).unwrap(), - col("b", &probe_side_schema).unwrap(), - ), - ]; - let hash_join = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_coalesce, - on, - None, - &JoinType::Inner, - None, - PartitionMode::CollectLeft, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - // Top-level CoalesceBatchesExec - let cb = - Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc<dyn ExecutionPlan>; - // Top-level CoalescePartitionsExec - let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc<dyn ExecutionPlan>; - // Add a sort for deterministic output - let plan = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new( - col("a", &probe_side_schema).unwrap(), - SortOptions::new(true, false), // descending, nulls_first - )]) - .unwrap(), - cp, - )) as Arc<dyn ExecutionPlan>; - - // expect the predicate to be pushed down into the probe side DataSource - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true - output: - Ok: - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - " - ); - - // Actually apply the optimization to the plan and execute to see the filter in action - let mut config = ConfigOptions::default(); - config.execution.parquet.pushdown_filters = true; - config.optimizer.enable_dynamic_filter_pushdown = true; - let plan = FilterPushdown::new_post_optimization() - .optimize(plan, &config) - .unwrap(); - let config = SessionConfig::new().with_batch_size(10); - let session_ctx = SessionContext::new_with_config(config); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = session_ctx.state(); - let task_ctx = state.task_ctx(); - let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) - .await - .unwrap(); - - // Now check what our filter looks like - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] - - CoalescePartitionsExec - - CoalesceBatchesExec: target_batch_size=8192 - - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, 
pushdown_supported=true - - CoalesceBatchesExec: target_batch_size=8192 - - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] - " - ); - - let result = format!("{}", pretty_format_batches(&batches).unwrap()); - - let probe_scan_metrics = probe_scan.metrics().unwrap(); - - // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. - // The number of output rows from the probe side scan should stay consistent across executions. - // Issue: https://github.com/apache/datafusion/issues/17451 - assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); - - insta::assert_snapshot!( - result, - @r" - +----+----+-----+----+----+-----+ - | a | b | c | a | b | e | - +----+----+-----+----+----+-----+ - | ab | bb | 2.0 | ab | bb | 2.0 | - | aa | ba | 1.0 | aa | ba | 1.0 | - +----+----+-----+----+----+-----+ - ", - ); -} - -#[tokio::test] -async fn test_nested_hashjoin_dynamic_filter_pushdown() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - // Create test data for three tables: t1, t2, t3 - // t1: small table with limited values (will be build side of outer join) - let t1_batches = - vec![ - record_batch!(("a", Utf8, ["aa", "ab"]), ("x", Float64, [1.0, 2.0])).unwrap(), - ]; - let t1_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("x", DataType::Float64, false), - ])); - let t1_scan = TestScanBuilder::new(Arc::clone(&t1_schema)) - .with_support(true) - .with_batches(t1_batches) - .build(); - - // t2: larger table (will be probe side of inner join, build side of outer join) - let t2_batches = vec![record_batch!( - ("b", Utf8, ["aa", "ab", "ac", "ad", "ae"]), - ("c", Utf8, ["ca", "cb", "cc", "cd", "ce"]), - ("y", Float64, [1.0, 2.0, 3.0, 4.0, 5.0]) - ) - .unwrap()]; - let t2_schema = Arc::new(Schema::new(vec![ - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Utf8, false), - Field::new("y", DataType::Float64, false), - ])); - let t2_scan = TestScanBuilder::new(Arc::clone(&t2_schema)) - .with_support(true) - .with_batches(t2_batches) - .build(); - - // t3: largest table (will be probe side of inner join) - let t3_batches = vec![record_batch!( - ("d", Utf8, ["ca", "cb", "cc", "cd", "ce", "cf", "cg", "ch"]), - ("z", Float64, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - ) - .unwrap()]; - let t3_schema = Arc::new(Schema::new(vec![ - Field::new("d", DataType::Utf8, false), - Field::new("z", DataType::Float64, false), - ])); - let t3_scan = TestScanBuilder::new(Arc::clone(&t3_schema)) - .with_support(true) - .with_batches(t3_batches) - .build(); - - // Create nested join structure: - // Join (t1.a = t2.b) - // / \ - // t1 Join(t2.c = t3.d) - // / \ - // t2 t3 - - // First create inner join: t2.c = t3.d - let inner_join_on = - vec![(col("c", &t2_schema).unwrap(), col("d", &t3_schema).unwrap())]; - let inner_join = Arc::new( - HashJoinExec::try_new( - t2_scan, - t3_scan, - inner_join_on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - // Then create outer join: t1.a = t2.b (from inner join result) - let outer_join_on = vec![( - col("a", &t1_schema).unwrap(), - col("b", &inner_join.schema()).unwrap(), - )]; - let outer_join = Arc::new( - 
HashJoinExec::try_new( - t1_scan, - inner_join as Arc<dyn ExecutionPlan>, - outer_join_on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ) as Arc<dyn ExecutionPlan>; - - // Test that dynamic filters are pushed down correctly through nested joins - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&outer_join), FilterPushdown::new_post_optimization(), true), - @r" - OptimizationTest: - input: - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true - output: - Ok: - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] - ", - ); - - // Execute the plan to verify the dynamic filters are properly updated - let mut config = ConfigOptions::default(); - config.execution.parquet.pushdown_filters = true; - config.optimizer.enable_dynamic_filter_pushdown = true; - let plan = FilterPushdown::new_post_optimization() - .optimize(outer_join, &config) - .unwrap(); - let config = SessionConfig::new().with_batch_size(10); - let session_ctx = SessionContext::new_with_config(config); - session_ctx.register_object_store( - ObjectStoreUrl::parse("test://").unwrap().as_ref(), - Arc::new(InMemory::new()), - ); - let state = session_ctx.state(); - let task_ctx = state.task_ctx(); - let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); - // Execute to populate the dynamic filters - stream.next().await.unwrap().unwrap(); - - // Verify that both the inner and outer join have updated dynamic filters - insta::assert_snapshot!( - format!("{}", format_plan_for_test(&plan)), - @r" - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab ] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb ] - " - ); -} - -#[tokio::test] -async fn test_hashjoin_parent_filter_pushdown() { - use datafusion_common::JoinType; - use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - - // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap()]; - let build_side_schema = Arc::new(Schema::new(vec![ - 
Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ])); - let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) - .with_support(true) - .with_batches(build_batches) - .build(); - - // Create probe side with more values - let probe_batches = vec![record_batch!( - ("d", Utf8, ["aa", "ab", "ac", "ad"]), - ("e", Utf8, ["ba", "bb", "bc", "bd"]), - ("f", Float64, [1.0, 2.0, 3.0, 4.0]) - ) - .unwrap()]; - let probe_side_schema = Arc::new(Schema::new(vec![ - Field::new("d", DataType::Utf8, false), - Field::new("e", DataType::Utf8, false), - Field::new("f", DataType::Float64, false), - ])); - let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) - .with_support(true) - .with_batches(probe_batches) - .build(); - - // Create HashJoinExec - let on = vec![( - col("a", &build_side_schema).unwrap(), - col("d", &probe_side_schema).unwrap(), - )]; - let join = Arc::new( - HashJoinExec::try_new( - build_scan, - probe_scan, - on, - None, - &JoinType::Inner, - None, - PartitionMode::Partitioned, - datafusion_common::NullEquality::NullEqualsNothing, - ) - .unwrap(), - ); - - // Create filters that can be pushed down to different sides - // We need to create filters in the context of the join output schema - let join_schema = join.schema(); - - // Filter on build side column: a = 'aa' - let left_filter = col_lit_predicate("a", "aa", &join_schema); - // Filter on probe side column: e = 'ba' - let right_filter = col_lit_predicate("e", "ba", &join_schema); - // Filter that references both sides: a = d (should not be pushed down) - let cross_filter = Arc::new(BinaryExpr::new( - col("a", &join_schema).unwrap(), - Operator::Eq, - col("d", &join_schema).unwrap(), - )) as Arc; - - let filter = - Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); - let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); - let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) - as Arc; - - // Test that filters are pushed down correctly to each side of the join - insta::assert_snapshot!( - OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = d@3 - - FilterExec: e@4 = ba - - FilterExec: a@0 = aa - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: a@0 = d@3 - - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba - " - ); -} - -/// Integration test for dynamic filter pushdown with TopK. -/// We use an integration test because there are complex interactions in the optimizer rules -/// that the unit tests applying a single optimizer rule do not cover. 
-#[tokio::test] -async fn test_topk_dynamic_filter_pushdown_integration() { - let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>; - let mut cfg = SessionConfig::new(); - cfg.options_mut().execution.parquet.pushdown_filters = true; - cfg.options_mut().execution.parquet.max_row_group_size = 128; - let ctx = SessionContext::new_with_config(cfg); - ctx.register_object_store( - ObjectStoreUrl::parse("memory://").unwrap().as_ref(), - Arc::clone(&store), - ); - ctx.sql( - r" -COPY ( - SELECT 1372708800 + value AS t - FROM generate_series(0, 99999) - ORDER BY t - ) TO 'memory:///1.parquet' -STORED AS PARQUET; - ", - ) - .await - .unwrap() - .collect() - .await - .unwrap(); - - // Register the file with the context - ctx.register_parquet( - "topk_pushdown", - "memory:///1.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - // Create a TopK query that will use dynamic filter pushdown - let df = ctx - .sql(r"EXPLAIN ANALYZE SELECT t FROM topk_pushdown ORDER BY t LIMIT 10;") - .await - .unwrap(); - let batches = df.collect().await.unwrap(); - let explain = format!("{}", pretty_format_batches(&batches).unwrap()); - - assert!(explain.contains("output_rows=128")); // Read 1 row group - assert!(explain.contains("t@0 < 1372708809")); // Dynamic filter was applied - assert!( - explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99872"), - "{explain}" - ); - // Pushdown pruned most rows -} - -#[test] -fn test_filter_pushdown_through_union() { - let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); - let scan2 = TestScanBuilder::new(schema()).with_support(true).build(); - - let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); - - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); - - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown::new(), true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo - - UnionExec - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - UnionExec - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -/// Schema: -/// a: String -/// b: String -/// c: f64 -static TEST_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| { - let fields = vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ]; - Arc::new(Schema::new(fields)) -}); - -fn schema() -> SchemaRef { - Arc::clone(&TEST_SCHEMA) -} - -/// Returns a predicate that is a binary expression col = lit -fn col_lit_predicate( - column_name: &str, - scalar_value: impl Into<ScalarValue>, - schema: &Schema, -) -> Arc<dyn PhysicalExpr> { - let scalar_value = scalar_value.into(); - Arc::new(BinaryExpr::new( - Arc::new(Column::new_with_schema(column_name, schema).unwrap()), - Operator::Eq, - Arc::new(Literal::new(scalar_value)), - )) -} - -#[tokio::test] -async fn test_aggregate_filter_pushdown() { - // Test that filters can pass through AggregateExec even with aggregate functions - // when the filter references grouping columns - // Simulates: SELECT a, COUNT(b) FROM table WHERE a = 'x' GROUP BY 
-#[tokio::test]
-async fn test_aggregate_filter_pushdown() {
-    // Test that filters can pass through AggregateExec even with aggregate functions
-    // when the filter references grouping columns
-    // Simulates: SELECT a, COUNT(b) FROM table WHERE a = 'x' GROUP BY a
-
-    let batches =
-        vec![
-            record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(),
-        ];
-
-    let scan = TestScanBuilder::new(schema())
-        .with_support(true)
-        .with_batches(batches)
-        .build();
-
-    // Create an aggregate: GROUP BY a with COUNT(b)
-    let group_by = PhysicalGroupBy::new_single(vec![(
-        col("a", &schema()).unwrap(),
-        "a".to_string(),
-    )]);
-
-    // Add COUNT aggregate
-    let count_expr =
-        AggregateExprBuilder::new(count_udaf(), vec![col("b", &schema()).unwrap()])
-            .schema(schema())
-            .alias("count")
-            .build()
-            .unwrap();
-
-    let aggregate = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Partial,
-            group_by,
-            vec![count_expr.into()], // Has aggregate function
-            vec![None],              // No filter on the aggregate function
-            Arc::clone(&scan),
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    // Add a filter on the grouping column 'a'
-    let predicate = col_lit_predicate("a", "x", &schema());
-    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap())
-        as Arc<dyn ExecutionPlan>;
-
-    // Even with aggregate functions, filter on grouping column should be pushed through
-    insta::assert_snapshot!(
-        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: a@0 = x
-        - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count], ordering_mode=Sorted
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = x
-    "
-    );
-}
-
-#[tokio::test]
-async fn test_no_pushdown_filter_on_aggregate_result() {
-    // Test that filters on aggregate results (not grouping columns) are NOT pushed through
-    // SELECT a, COUNT(b) as cnt FROM table GROUP BY a HAVING cnt > 5
-    // The filter on 'cnt' cannot be pushed down because it's an aggregate result
-
-    let batches =
-        vec![
-            record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(),
-        ];
-
-    let scan = TestScanBuilder::new(schema())
-        .with_support(true)
-        .with_batches(batches)
-        .build();
-
-    // Create an aggregate: GROUP BY a with COUNT(b)
-    let group_by = PhysicalGroupBy::new_single(vec![(
-        col("a", &schema()).unwrap(),
-        "a".to_string(),
-    )]);
-
-    // Add COUNT aggregate
-    let count_expr =
-        AggregateExprBuilder::new(count_udaf(), vec![col("b", &schema()).unwrap()])
-            .schema(schema())
-            .alias("count")
-            .build()
-            .unwrap();
-
-    let aggregate = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Partial,
-            group_by,
-            vec![count_expr.into()],
-            vec![None],
-            Arc::clone(&scan),
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    // Add a filter on the aggregate output column
-    // This simulates filtering on COUNT result, which should NOT be pushed through
-    let agg_schema = aggregate.schema();
-    let predicate = Arc::new(BinaryExpr::new(
-        Arc::new(Column::new_with_schema("count[count]", &agg_schema).unwrap()),
-        Operator::Gt,
-        Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
-    ));
-    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap())
-        as Arc<dyn ExecutionPlan>;
-
-    // The filter should NOT be pushed through the aggregate since it's on an aggregate result
-    insta::assert_snapshot!(
-        OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: count[count]@1 > 5
-        - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - FilterExec: count[count]@1 > 5
-          - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count]
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-    "
-    );
-}
-
-#[test]
-fn test_pushdown_filter_on_non_first_grouping_column() {
-    // Test that filters on non-first grouping columns are still pushed down
-    // SELECT a, b, count(*) as cnt FROM table GROUP BY a, b HAVING b = 'bar'
-    // The filter is on 'b' (second grouping column), should push down
-    let scan = TestScanBuilder::new(schema()).with_support(true).build();
-
-    let aggregate_expr =
-        vec![
-            AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()])
-                .schema(schema())
-                .alias("cnt")
-                .build()
-                .map(Arc::new)
-                .unwrap(),
-        ];
-
-    let group_by = PhysicalGroupBy::new_single(vec![
-        (col("a", &schema()).unwrap(), "a".to_string()),
-        (col("b", &schema()).unwrap(), "b".to_string()),
-    ]);
-
-    let aggregate = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Final,
-            group_by,
-            aggregate_expr.clone(),
-            vec![None],
-            scan,
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    let predicate = col_lit_predicate("b", "bar", &schema());
-    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
-
-    insta::assert_snapshot!(
-        OptimizationTest::new(plan, FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: b@1 = bar
-        - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([1])
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar
-    "
-    );
-}
-
-#[test]
-fn test_no_pushdown_grouping_sets_filter_on_missing_column() {
-    // Test that filters on columns missing from some grouping sets are NOT pushed through
-    let scan = TestScanBuilder::new(schema()).with_support(true).build();
-
-    let aggregate_expr =
-        vec![
-            AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()])
-                .schema(schema())
-                .alias("cnt")
-                .build()
-                .map(Arc::new)
-                .unwrap(),
-        ];
-
-    // Create GROUPING SETS with (a, b) and (b)
-    let group_by = PhysicalGroupBy::new(
-        vec![
-            (col("a", &schema()).unwrap(), "a".to_string()),
-            (col("b", &schema()).unwrap(), "b".to_string()),
-        ],
-        vec![
-            (
-                Arc::new(Literal::new(ScalarValue::Utf8(None))),
-                "a".to_string(),
-            ),
-            (
-                Arc::new(Literal::new(ScalarValue::Utf8(None))),
-                "b".to_string(),
-            ),
-        ],
-        vec![
-            vec![false, false], // (a, b) - both present
-            vec![true, false],  // (b) - a is NULL, b present
-        ],
-    );
-
-    let aggregate = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Final,
-            group_by,
-            aggregate_expr.clone(),
-            vec![None],
-            scan,
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    // Filter on column 'a' which is missing in the second grouping set, should not be pushed down
-    let predicate = col_lit_predicate("a", "foo", &schema());
-    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
-
-    insta::assert_snapshot!(
-        OptimizationTest::new(plan, FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: a@0 = foo
-        - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - FilterExec: a@0 = foo
-          - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt]
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-    "
-    );
-}
-
-#[test]
-fn test_pushdown_grouping_sets_filter_on_common_column() {
-    // Test that filters on columns present in ALL grouping sets ARE pushed through
-    let scan = TestScanBuilder::new(schema()).with_support(true).build();
-
-    let aggregate_expr =
-        vec![
-            AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()])
-                .schema(schema())
-                .alias("cnt")
-                .build()
-                .map(Arc::new)
-                .unwrap(),
-        ];
-
-    // Create GROUPING SETS with (a, b) and (b)
-    let group_by = PhysicalGroupBy::new(
-        vec![
-            (col("a", &schema()).unwrap(), "a".to_string()),
-            (col("b", &schema()).unwrap(), "b".to_string()),
-        ],
-        vec![
-            (
-                Arc::new(Literal::new(ScalarValue::Utf8(None))),
-                "a".to_string(),
-            ),
-            (
-                Arc::new(Literal::new(ScalarValue::Utf8(None))),
-                "b".to_string(),
-            ),
-        ],
-        vec![
-            vec![false, false], // (a, b) - both present
-            vec![true, false],  // (b) - a is NULL, b present
-        ],
-    );
-
-    let aggregate = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Final,
-            group_by,
-            aggregate_expr.clone(),
-            vec![None],
-            scan,
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    // Filter on column 'b' which is present in all grouping sets will be pushed down
-    let predicate = col_lit_predicate("b", "bar", &schema());
-    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
-
-    insta::assert_snapshot!(
-        OptimizationTest::new(plan, FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: b@1 = bar
-        - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt], ordering_mode=PartiallySorted([1])
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar
-    "
-    );
-}
-
-#[test]
-fn test_pushdown_with_empty_group_by() {
-    // Test that filters can be pushed down when GROUP BY is empty (no grouping columns)
-    // SELECT count(*) as cnt FROM table WHERE a = 'foo'
-    // There are no grouping columns, so the filter should still push down
-    let scan = TestScanBuilder::new(schema()).with_support(true).build();
-
-    let aggregate_expr =
-        vec![
-            AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()])
-                .schema(schema())
-                .alias("cnt")
-                .build()
-                .map(Arc::new)
-                .unwrap(),
-        ];
-
-    // Empty GROUP BY - no grouping columns
-    let group_by = PhysicalGroupBy::new_single(vec![]);
-
-    let aggregate = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Final,
-            group_by,
-            aggregate_expr.clone(),
-            vec![None],
-            scan,
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    // Filter on 'a'
-    let predicate = col_lit_predicate("a", "foo", &schema());
-    let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
-
-    // The filter should be pushed down even with empty GROUP BY
-    insta::assert_snapshot!(
-        OptimizationTest::new(plan, FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - FilterExec: a@0 = foo
-        - AggregateExec: mode=Final, gby=[], aggr=[cnt]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - AggregateExec: mode=Final, gby=[], aggr=[cnt]
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
-    "
-    );
-}
-
-#[test]
-fn test_pushdown_with_computed_grouping_key() {
-    // Test filter pushdown with computed grouping expression
-    // SELECT (c + 1.0) as c_plus_1, count(*) FROM table WHERE c > 5.0 GROUP BY (c + 1.0)
-
-    let scan = TestScanBuilder::new(schema()).with_support(true).build();
-
-    let predicate = Arc::new(BinaryExpr::new(
-        col("c", &schema()).unwrap(),
-        Operator::Gt,
-        Arc::new(Literal::new(ScalarValue::Float64(Some(5.0)))),
-    )) as Arc<dyn PhysicalExpr>;
-    let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap());
-
-    let aggregate_expr =
-        vec![
-            AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()])
-                .schema(schema())
-                .alias("cnt")
-                .build()
-                .map(Arc::new)
-                .unwrap(),
-        ];
-
-    let c_plus_one = Arc::new(BinaryExpr::new(
-        col("c", &schema()).unwrap(),
-        Operator::Plus,
-        Arc::new(Literal::new(ScalarValue::Float64(Some(1.0)))),
-    )) as Arc<dyn PhysicalExpr>;
-
-    let group_by =
-        PhysicalGroupBy::new_single(vec![(c_plus_one, "c_plus_1".to_string())]);
-
-    let plan = Arc::new(
-        AggregateExec::try_new(
-            AggregateMode::Final,
-            group_by,
-            aggregate_expr.clone(),
-            vec![None],
-            filter,
-            schema(),
-        )
-        .unwrap(),
-    );
-
-    // The filter should be pushed down because 'c' is extracted from the grouping expression (c + 1.0)
-    insta::assert_snapshot!(
-        OptimizationTest::new(plan, FilterPushdown::new(), true),
-        @r"
-    OptimizationTest:
-      input:
-        - AggregateExec: mode=Final, gby=[c@2 + 1 as c_plus_1], aggr=[cnt]
-        - FilterExec: c@2 > 5
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
-      output:
-        Ok:
-          - AggregateExec: mode=Final, gby=[c@2 + 1 as c_plus_1], aggr=[cnt]
-          - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=c@2 > 5
-    "
-    );
-}
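Editorial aside: all of the removed tests above drive `FilterPushdown` through the `OptimizationTest` snapshot harness. The rule is an ordinary `PhysicalOptimizerRule`, so it can also be applied directly, the same way the join-selection and limit-pushdown tests later in this diff invoke their rules. A sketch, assuming a `plan: Arc<dyn ExecutionPlan>` built as in these tests:

// Apply the optimizer rule directly; OptimizationTest does essentially
// this and then snapshots the plan before and after.
let optimized = FilterPushdown::new()
    .optimize(plan, &ConfigOptions::new())
    .unwrap();

diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs
index f9d3a045469e1..050baa9e792e9 100644
--- a/datafusion/core/tests/physical_optimizer/join_selection.rs
+++ b/datafusion/core/tests/physical_optimizer/join_selection.rs
@@ -18,7 +18,6 @@
 use insta::assert_snapshot;
 use std::sync::Arc;
 use std::{
-    any::Any,
     pin::Pin,
     task::{Context, Poll},
 };
@@ -26,27 +25,28 @@
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue};
+use datafusion_common::tree_node::TreeNodeRecursion;
+use datafusion_common::{ColumnStatistics, JoinType, ScalarValue, stats::Precision};
 use datafusion_common::{JoinSide, NullEquality};
 use datafusion_common::{Result, Statistics};
 use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
 use datafusion_expr::Operator;
+use datafusion_physical_expr::PhysicalExprRef;
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr};
 use datafusion_physical_expr::intervals::utils::check_support;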
-use datafusion_physical_expr::PhysicalExprRef;
 use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr};
-use datafusion_physical_optimizer::join_selection::JoinSelection;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_optimizer::join_selection::JoinSelection;
+use datafusion_physical_plan::ExecutionPlanProperties;
 use datafusion_physical_plan::displayable;
 use datafusion_physical_plan::joins::utils::ColumnIndex;
 use datafusion_physical_plan::joins::utils::JoinFilter;
 use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode};
 use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::ExecutionPlanProperties;
 use datafusion_physical_plan::{
-    execution_plan::{Boundedness, EmissionType},
     DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
+    execution_plan::{Boundedness, EmissionType},
 };
 use futures::Stream;
@@ -222,6 +222,7 @@ async fn test_join_with_swap() {
             None,
             PartitionMode::CollectLeft,
             NullEquality::NullEqualsNothing,
+            false,
         )
         .unwrap(),
     );
@@ -231,7 +232,6 @@ async fn test_join_with_swap() {
         .unwrap();
 
     let swapping_projection = optimized_join
-        .as_any()
        .downcast_ref::<ProjectionExec>()
         .expect("A proj is required to swap columns back to their original order");
@@ -245,7 +245,6 @@ async fn test_join_with_swap() {
     let swapped_join = swapping_projection
         .input()
-        .as_any()
        .downcast_ref::<HashJoinExec>()
         .expect("The type of the plan should not be changed");
@@ -284,6 +283,7 @@ async fn test_left_join_no_swap() {
             None,
             PartitionMode::CollectLeft,
             NullEquality::NullEqualsNothing,
+            false,
         )
         .unwrap(),
     );
@@ -293,7 +293,6 @@ async fn test_left_join_no_swap() {
         .unwrap();
 
     let swapped_join = optimized_join
-        .as_any()
        .downcast_ref::<HashJoinExec>()
         .expect("The type of the plan should not be changed");
@@ -333,6 +332,7 @@ async fn test_join_with_swap_semi() {
         None,
         PartitionMode::Partitioned,
         NullEquality::NullEqualsNothing,
+        false,
     )
     .unwrap();
@@ -342,12 +342,9 @@ async fn test_join_with_swap_semi() {
         .optimize(Arc::new(join), &ConfigOptions::new())
         .unwrap();
 
-    let swapped_join = optimized_join
-        .as_any()
-        .downcast_ref::<HashJoinExec>()
-        .expect(
-            "A proj is not required to swap columns back to their original order",
-        );
+    let swapped_join = optimized_join.downcast_ref::<HashJoinExec>().expect(
+        "A proj is not required to swap columns back to their original order",
+    );
 
     assert_eq!(swapped_join.schema().fields().len(), 1);
     assert_eq!(
@@ -388,6 +385,7 @@ async fn test_join_with_swap_mark() {
         None,
         PartitionMode::Partitioned,
         NullEquality::NullEqualsNothing,
+        false,
     )
     .unwrap();
@@ -397,12 +395,9 @@ async fn test_join_with_swap_mark() {
         .optimize(Arc::new(join), &ConfigOptions::new())
         .unwrap();
 
-    let swapped_join = optimized_join
-        .as_any()
-        .downcast_ref::<HashJoinExec>()
-        .expect(
-            "A proj is not required to swap columns back to their original order",
-        );
+    let swapped_join = optimized_join.downcast_ref::<HashJoinExec>().expect(
+        "A proj is not required to swap columns back to their original order",
+    );
 
     assert_eq!(swapped_join.schema().fields().len(), 2);
     assert_eq!(
@@ -461,6 +456,7 @@ async fn test_nested_join_swap() {
         None,
         PartitionMode::CollectLeft,
         NullEquality::NullEqualsNothing,
+        false,
     )
     .unwrap();
     let child_schema = child_join.schema();
@@ -478,6 +474,7 @@ async fn test_nested_join_swap() {
         None,
         PartitionMode::CollectLeft,
         NullEquality::NullEqualsNothing,
+        false,
     )
     .unwrap();
@@ -518,6 +515,7 @@ async fn test_join_no_swap() {
             None,
             PartitionMode::CollectLeft,
             NullEquality::NullEqualsNothing,
+            false,
         )
         .unwrap(),
     );
@@ -527,7 +525,6 @@ async fn test_join_no_swap() {
         .unwrap();
 
     let swapped_join = optimized_join
-        .as_any()
        .downcast_ref::<HashJoinExec>()
         .expect("The type of the plan should not be changed");
@@ -576,7 +573,6 @@ async fn test_nl_join_with_swap(join_type: JoinType) {
         .unwrap();
 
     let swapping_projection = optimized_join
-        .as_any()
        .downcast_ref::<ProjectionExec>()
         .expect("A proj is required to swap columns back to their original order");
@@ -590,7 +586,6 @@ async fn test_nl_join_with_swap(join_type: JoinType) {
     let swapped_join = swapping_projection
         .input()
-        .as_any()
        .downcast_ref::<NestedLoopJoinExec>()
         .expect("The type of the plan should not be changed");
@@ -657,7 +652,6 @@ async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
         .unwrap();
 
     let swapped_join = optimized_join
-        .as_any()
        .downcast_ref::<NestedLoopJoinExec>()
         .expect("The type of the plan should not be changed");
@@ -745,16 +739,19 @@ async fn test_hash_join_swap_on_joins_with_projections(
         Some(projection),
         PartitionMode::Partitioned,
         NullEquality::NullEqualsNothing,
+        false,
     )?);
 
     let swapped = join
         .swap_inputs(PartitionMode::Partitioned)
         .expect("swap_hash_join must support joins with projections");
-    let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
+    let swapped_join = swapped
+        .downcast_ref::<HashJoinExec>()
+        .expect(
         "ProjectionExec won't be added above if HashJoinExec contains embedded projection",
     );
 
-    assert_eq!(swapped_join.projection, Some(vec![0_usize]));
+    assert_eq!(swapped_join.projection.as_deref().unwrap(), &[0_usize]);
     assert_eq!(swapped.schema().fields.len(), 1);
     assert_eq!(swapped.schema().fields[0].name(), "small_col");
     Ok(())
@@ -762,7 +759,6 @@ async fn test_hash_join_swap_on_joins_with_projections(
 
 fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
     let col = expr
-        .as_any()
        .downcast_ref::<Column>()
         .expect("Projection items should be Column expression");
     assert_eq!(col.name(), name);
@@ -906,6 +902,7 @@ fn check_join_partition_mode(
             None,
             PartitionMode::Auto,
             NullEquality::NullEqualsNothing,
+            false,
         )
         .unwrap(),
     );
@@ -916,18 +913,15 @@ fn check_join_partition_mode(
     if !is_swapped {
         let swapped_join = optimized_join
-            .as_any()
            .downcast_ref::<HashJoinExec>()
             .expect("The type of the plan should not be changed");
         assert_eq!(*swapped_join.partition_mode(), expected_mode);
     } else {
         let swapping_projection = optimized_join
-            .as_any()
            .downcast_ref::<ProjectionExec>()
             .expect("A proj is required to swap columns back to their original order");
         let swapped_join = swapping_projection
             .input()
-            .as_any()
            .downcast_ref::<HashJoinExec>()
             .expect("The type of the plan should not be changed");
@@ -949,10 +943,10 @@ impl Stream for UnboundedStream {
         mut self: Pin<&mut Self>,
         _cx: &mut Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
-        if let Some(val) = self.batch_produce {
-            if val <= self.count {
-                return Poll::Ready(None);
-            }
+        if let Some(val) = self.batch_produce
+            && val <= self.count
+        {
+            return Poll::Ready(None);
         }
         self.count += 1;
         Poll::Ready(Some(Ok(self.batch.clone())))
@@ -970,7 +964,7 @@ impl RecordBatchStream for UnboundedStream {
 pub struct UnboundedExec {
     batch_produce: Option<usize>,
     batch: RecordBatch,
-    cache: PlanProperties,
+    cache: Arc<PlanProperties>,
 }
 
 impl UnboundedExec {
@@ -986,7 +980,7 @@ impl UnboundedExec {
         Self {
             batch_produce,
             batch,
-            cache,
+            cache: Arc::new(cache),
         }
     }
 
@@ -1039,11 +1033,7 @@ impl ExecutionPlan for UnboundedExec {
         Self::static_name()
     }
 
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.cache
     }
 
@@ -1069,6 +1059,20 @@ impl ExecutionPlan for UnboundedExec {
             batch: self.batch.clone(),
         }))
     }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
+    }
 }
 
 #[derive(Eq, PartialEq, Debug)]
@@ -1082,20 +1086,21 @@ pub enum SourceType {
 pub struct StatisticsExec {
     stats: Statistics,
     schema: Arc<Schema>,
-    cache: PlanProperties,
+    cache: Arc<PlanProperties>,
 }
 
 impl StatisticsExec {
     pub fn new(stats: Statistics, schema: Schema) -> Self {
         assert_eq!(
-            stats.column_statistics.len(), schema.fields().len(),
-            "if defined, the column statistics vector length should be the number of fields"
-        );
+            stats.column_statistics.len(),
+            schema.fields().len(),
+            "if defined, the column statistics vector length should be the number of fields"
+        );
         let cache = Self::compute_properties(Arc::new(schema.clone()));
         Self {
             stats,
             schema: Arc::new(schema),
-            cache,
+            cache: Arc::new(cache),
         }
     }
 
@@ -1139,11 +1144,7 @@ impl ExecutionPlan for StatisticsExec {
         Self::static_name()
     }
 
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.cache
     }
 
@@ -1166,16 +1167,26 @@ impl ExecutionPlan for StatisticsExec {
         unimplemented!("This plan only serves for testing statistics")
     }
 
-    fn statistics(&self) -> Result<Statistics> {
-        Ok(self.stats.clone())
-    }
-
-    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
-        Ok(if partition.is_some() {
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
+        Ok(Arc::new(if partition.is_some() {
             Statistics::new_unknown(&self.schema)
         } else {
             self.stats.clone()
-        })
+        }))
+    }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.cache.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
     }
 }
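Editorial aside: both test operators implement the new `apply_expressions` hook the same way, walking the sort expressions of their cached output ordering and letting the callback steer traversal via `TreeNodeRecursion`. A hypothetical caller-side sketch (variable names invented) of how a visitor consumes the hook:

// Collect a textual form of every expression the node exposes. Returning
// TreeNodeRecursion::Continue keeps visiting the remaining siblings.
let mut seen: Vec<String> = Vec::new();
plan.apply_expressions(&mut |expr| {
    seen.push(expr.to_string());
    Ok(TreeNodeRecursion::Continue)
})?;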
@@ -1553,16 +1564,16 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> {
         None,
         t.initial_mode,
         NullEquality::NullEqualsNothing,
+        false,
     )?) as _;
 
     let optimized_join_plan =
         JoinSelection::new().optimize(Arc::clone(&join), &ConfigOptions::new())?;
 
     // If swap did happen
-    let projection_added = optimized_join_plan.as_any().is::<ProjectionExec>();
+    let projection_added = optimized_join_plan.is::<ProjectionExec>();
     let plan = if projection_added {
         let proj = optimized_join_plan
-            .as_any()
            .downcast_ref::<ProjectionExec>()
             .expect("A proj is required to swap columns back to their original order");
        Arc::<dyn ExecutionPlan>::clone(proj.input())
@@ -1576,7 +1587,7 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> {
         join_type,
         mode,
         ..
-    }) = plan.as_any().downcast_ref::<HashJoinExec>()
+    }) = plan.downcast_ref::<HashJoinExec>()
     {
         let left_changed = Arc::ptr_eq(left, &right_exec);
         let right_changed = Arc::ptr_eq(right, &left_exec);
diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs
index 56d48901f284d..5f9b7e50848fd 100644
--- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs
+++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs
@@ -18,7 +18,7 @@
 use std::sync::Arc;
 
 use crate::physical_optimizer::test_utils::{
-    coalesce_batches_exec, coalesce_partitions_exec, global_limit_exec, local_limit_exec,
+    coalesce_partitions_exec, global_limit_exec, hash_join_exec, local_limit_exec,
     sort_exec, sort_preserving_merge_exec, stream_exec,
 };
 
@@ -26,17 +26,19 @@
 use arrow::compute::SortOptions;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::error::Result;
-use datafusion_expr::Operator;
-use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
+use datafusion_expr::{JoinType, Operator};
 use datafusion_physical_expr::Partitioning;
+use datafusion_physical_expr::expressions::{BinaryExpr, col, lit};
+use datafusion_physical_expr_common::physical_expr::PhysicalExprRef;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
-use datafusion_physical_optimizer::limit_pushdown::LimitPushdown;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_optimizer::limit_pushdown::LimitPushdown;
 use datafusion_physical_plan::empty::EmptyExec;
 use datafusion_physical_plan::filter::FilterExec;
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
 use datafusion_physical_plan::projection::ProjectionExec;
 use datafusion_physical_plan::repartition::RepartitionExec;
-use datafusion_physical_plan::{get_plan_string, ExecutionPlan};
+use datafusion_physical_plan::{ExecutionPlan, get_plan_string};
 
 fn create_schema() -> SchemaRef {
     Arc::new(Schema::new(vec![
@@ -87,6 +89,20 @@ fn empty_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
     Arc::new(EmptyExec::new(schema))
 }
 
+fn nested_loop_join_exec(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    join_type: JoinType,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    Ok(Arc::new(NestedLoopJoinExec::try_new(
+        left, right, None, &join_type, None,
+    )?))
+}
+
+fn format_plan(plan: &Arc<dyn ExecutionPlan>) -> String {
+    get_plan_string(plan).join("\n")
+}
+
 #[test]
 fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()>
 {
@@ -94,148 +110,251 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() ->
     let schema = create_schema();
     let streaming_table = stream_exec(&schema);
     let global_limit = global_limit_exec(streaming_table, 0, Some(5));
 
-    let initial = get_plan_string(&global_limit);
-    let expected_initial = [
-        "GlobalLimitExec: skip=0, fetch=5",
-        "  StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=0, fetch=5
+      StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?;
 
-    let expected = [
-        "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5"
-    ];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
@"StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" + ); Ok(()) } #[test] -fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( -) -> Result<()> { +fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero() +-> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7 + " + ); Ok(()) } +fn join_on_columns( + left_col: &str, + right_col: &str, +) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + vec![( + Arc::new(datafusion_physical_expr::expressions::Column::new( + left_col, 0, + )) as _, + Arc::new(datafusion_physical_expr::expressions::Column::new( + right_col, 0, + )) as _, + )] +} + #[test] -fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( -) -> Result<()> { +fn absorbs_limit_into_hash_join_inner() -> Result<()> { + // HashJoinExec with Inner join should absorb limit via with_fetch let schema = create_schema(); - let streaming_table = stream_exec(&schema); - let repartition = repartition_exec(streaming_table)?; - let filter = filter_exec(schema, repartition)?; - let coalesce_batches = coalesce_batches_exec(filter, 8192); - let local_limit = local_limit_exec(coalesce_batches, 5); - let coalesce_partitions = coalesce_partitions_exec(local_limit); - let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=5", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, 
&ConfigOptions::new())?;
-
-    let expected = [
-        "CoalescePartitionsExec: fetch=5",
-        "  CoalesceBatchesExec: target_batch_size=8192, fetch=5",
-        "    FilterExec: c3@2 > 0",
-        "      RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
-        "        StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    // The limit should be absorbed by the hash join (not pushed to children)
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=5
+      EmptyExec
+      EmptyExec
+    "
+    );
 
     Ok(())
 }
 
 #[test]
-fn pushes_global_limit_exec_through_projection_exec() -> Result<()> {
+fn absorbs_limit_into_hash_join_right() -> Result<()> {
+    // HashJoinExec with Right join should absorb limit via with_fetch
     let schema = create_schema();
-    let streaming_table = stream_exec(&schema);
-    let filter = filter_exec(Arc::clone(&schema), streaming_table)?;
-    let projection = projection_exec(schema, filter)?;
-    let global_limit = global_limit_exec(projection, 0, Some(5));
+    let left = empty_exec(Arc::clone(&schema));
+    let right = empty_exec(Arc::clone(&schema));
+    let on = join_on_columns("c1", "c1");
+    let hash_join = hash_join_exec(left, right, on, None, &JoinType::Right)?;
+    let global_limit = global_limit_exec(hash_join, 0, Some(10));
+
+    let initial = format_plan(&global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=0, fetch=10
+      HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)]
+        EmptyExec
+        EmptyExec
+    "
+    );
+
+    let after_optimize =
+        LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?;
+    let optimized = format_plan(&after_optimize);
+    // The limit should be absorbed by the hash join
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)], fetch=10
+      EmptyExec
+      EmptyExec
+    "
+    );
 
-    let initial = get_plan_string(&global_limit);
-    let expected_initial = [
-        "GlobalLimitExec: skip=0, fetch=5",
-        "  ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]",
-        "    FilterExec: c3@2 > 0",
-        "      StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-    assert_eq!(initial, expected_initial);
+    Ok(())
+}
+
+#[test]
+fn absorbs_limit_into_hash_join_left() -> Result<()> {
+    // during probing, then unmatched rows at the end, stopping when limit is reached
+    let schema = create_schema();
+    let left = empty_exec(Arc::clone(&schema));
+    let right = empty_exec(Arc::clone(&schema));
+    let on = join_on_columns("c1", "c1");
+    let hash_join = hash_join_exec(left, right, on, None, &JoinType::Left)?;
+    let global_limit = global_limit_exec(hash_join, 0, Some(5));
+
+    let initial = format_plan(&global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=0, fetch=5
+      HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)]
+        EmptyExec
+        EmptyExec
+    "
+    );
 
     let after_optimize =
        LimitPushdown::new().optimize(global_limit,
partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + Ok(()) +} + +#[test] +fn absorbs_limit_with_skip_into_hash_join() -> Result<()> { + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 3, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // With skip, GlobalLimit is kept but fetch (skip + limit = 8) is absorbed by the join + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=8 + EmptyExec + EmptyExec + " + ); Ok(()) } #[test] -fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( -) -> Result<()> { +fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); - let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); - let projection = projection_exec(schema, coalesce_batches)?; + let filter = filter_exec(Arc::clone(&schema), streaming_table)?; + let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -244,8 +363,7 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); - let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); - let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; + let projection = 
@@ -244,8 +363,7 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> {
     let schema = create_schema();
     let streaming_table = stream_exec(&schema);
-    let coalesce_batches = coalesce_batches_exec(streaming_table, 8192);
-    let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?;
+    let projection = projection_exec(Arc::clone(&schema), streaming_table)?;
     let repartition = repartition_exec(projection)?;
     let ordering: LexOrdering = [PhysicalSortExpr {
         expr: col("c1", &schema)?,
@@ -256,31 +374,33 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> {
     let spm = sort_preserving_merge_exec(ordering, sort);
     let global_limit = global_limit_exec(spm, 0, Some(5));
 
-    let initial = get_plan_string(&global_limit);
-    let expected_initial = [
-        "GlobalLimitExec: skip=0, fetch=5",
-        "  SortPreservingMergeExec: [c1@0 ASC]",
-        "    SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]",
-        "      RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
-        "        ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]",
-        "          CoalesceBatchesExec: target_batch_size=8192",
-        "            StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=0, fetch=5
+      SortPreservingMergeExec: [c1@0 ASC]
+        SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]
+          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
+            ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]
+              StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?;
 
-    let expected = [
-        "SortPreservingMergeExec: [c1@0 ASC], fetch=5",
-        "  SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]",
-        "    RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
-        "      ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]",
-        "        CoalesceBatchesExec: target_batch_size=8192",
-        "          StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    SortPreservingMergeExec: [c1@0 ASC], fetch=5
+      SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]
+        RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
+          ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]
+            StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true
+    "
+    );
 
     Ok(())
 }
@@ -295,26 +415,31 @@ fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> R
     let coalesce_partitions = coalesce_partitions_exec(filter);
     let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5));
 
-    let initial = get_plan_string(&global_limit);
-    let expected_initial = [
-        "GlobalLimitExec: skip=0, fetch=5",
-        "  CoalescePartitionsExec",
-        "    FilterExec: c3@2 > 0",
-        "      RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
-        "        StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=0, fetch=5
+      CoalescePartitionsExec
+        FilterExec: c3@2 > 0
+          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
+            StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?;
 
-    let expected = [
-        "CoalescePartitionsExec: fetch=5",
-        "  FilterExec: c3@2 > 0",
-        "    RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
-        "      StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true"
-    ];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    CoalescePartitionsExec: fetch=5
+      FilterExec: c3@2 > 0, fetch=5
+        RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
+          StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true
+    "
+    );
 
     Ok(())
 }
@@ -326,20 +451,27 @@ fn merges_local_limit_with_local_limit() -> Result<()> {
     let child_local_limit = local_limit_exec(empty_exec, 10);
     let parent_local_limit = local_limit_exec(child_local_limit, 20);
 
-    let initial = get_plan_string(&parent_local_limit);
-    let expected_initial = [
-        "LocalLimitExec: fetch=20",
-        "  LocalLimitExec: fetch=10",
-        "    EmptyExec",
-    ];
-
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&parent_local_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    LocalLimitExec: fetch=20
+      LocalLimitExec: fetch=10
+        EmptyExec
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?;
 
-    let expected = ["GlobalLimitExec: skip=0, fetch=10", "  EmptyExec"];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    GlobalLimitExec: skip=0, fetch=10
+      EmptyExec
+    "
+    );
 
     Ok(())
 }
@@ -351,20 +483,27 @@ fn merges_global_limit_with_global_limit() -> Result<()> {
     let child_global_limit = global_limit_exec(empty_exec, 10, Some(30));
     let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20));
 
-    let initial = get_plan_string(&parent_global_limit);
-    let expected_initial = [
-        "GlobalLimitExec: skip=10, fetch=20",
-        "  GlobalLimitExec: skip=10, fetch=30",
-        "    EmptyExec",
-    ];
-
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&parent_global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=10, fetch=20
+      GlobalLimitExec: skip=10, fetch=30
+        EmptyExec
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?;
 
-    let expected = ["GlobalLimitExec: skip=20, fetch=20", "  EmptyExec"];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    GlobalLimitExec: skip=20, fetch=20
+      EmptyExec
+    "
+    );
 
     Ok(())
 }
@@ -376,20 +515,27 @@ fn merges_global_limit_with_local_limit() -> Result<()> {
     let local_limit = local_limit_exec(empty_exec, 40);
     let global_limit = global_limit_exec(local_limit, 20, Some(30));
 
-    let initial = get_plan_string(&global_limit);
-    let expected_initial = [
-        "GlobalLimitExec: skip=20, fetch=30",
-        "  LocalLimitExec: fetch=40",
-        "    EmptyExec",
-    ];
-
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&global_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=20, fetch=30
+      LocalLimitExec: fetch=40
+        EmptyExec
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?;
 
-    let expected = ["GlobalLimitExec: skip=20, fetch=20", "  EmptyExec"];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    GlobalLimitExec: skip=20, fetch=20
+      EmptyExec
+    "
+    );
 
     Ok(())
 }
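Editorial aside: the merged `skip`/`fetch` values asserted by these limit-merging tests all follow one composition rule: skips add up, and the parent's fetch is capped by whatever survives the child's fetch after the extra skip. A minimal sketch of that rule; `merge_limits` is a hypothetical helper for illustration, not an API of `LimitPushdown`:

// Compose parent (skip2, fetch2) over child (skip1, fetch1); the child
// limit is applied first and the parent consumes its output.
fn merge_limits(
    (skip1, fetch1): (usize, Option<usize>),
    (skip2, fetch2): (usize, Option<usize>),
) -> (usize, Option<usize>) {
    // Rows surviving the child limit after the parent's extra skip:
    let remaining = fetch1.map(|f| f.saturating_sub(skip2));
    let fetch = match (remaining, fetch2) {
        (Some(r), Some(f)) => Some(r.min(f)),
        (Some(r), None) => Some(r),
        (None, f) => f,
    };
    (skip1 + skip2, fetch)
}

#[test]
fn merge_limits_matches_snapshots() {
    // merges_global_limit_with_global_limit: (skip=10, fetch=30) then
    // (skip=10, fetch=20) collapses to skip=20, fetch=20.
    assert_eq!(merge_limits((10, Some(30)), (10, Some(20))), (20, Some(20)));
    // merges_global_limit_with_local_limit: LocalLimit(40) then
    // GlobalLimit(skip=20, fetch=30) collapses to skip=20, fetch=20.
    assert_eq!(merge_limits((0, Some(40)), (20, Some(30))), (20, Some(20)));
}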
@@ -401,20 +547,173 @@ fn merges_local_limit_with_global_limit() -> Result<()> {
     let global_limit = global_limit_exec(empty_exec, 20, Some(30));
     let local_limit = local_limit_exec(global_limit, 20);
 
-    let initial = get_plan_string(&local_limit);
-    let expected_initial = [
-        "LocalLimitExec: fetch=20",
-        "  GlobalLimitExec: skip=20, fetch=30",
-        "    EmptyExec",
-    ];
-
-    assert_eq!(initial, expected_initial);
+    let initial = format_plan(&local_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    LocalLimitExec: fetch=20
+      GlobalLimitExec: skip=20, fetch=30
+        EmptyExec
+    "
+    );
 
     let after_optimize =
         LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?;
 
-    let expected = ["GlobalLimitExec: skip=20, fetch=20", "  EmptyExec"];
-    assert_eq!(get_plan_string(&after_optimize), expected);
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    GlobalLimitExec: skip=20, fetch=20
+      EmptyExec
+    "
+    );
+
+    Ok(())
+}
+
+#[test]
+fn preserves_nested_global_limit() -> Result<()> {
+    // If there are multiple limits in an execution plan, they all need to be
+    // preserved in the optimized plan.
+    //
+    // Plan structure:
+    //   GlobalLimitExec: skip=1, fetch=1
+    //     NestedLoopJoinExec (Left)
+    //       EmptyExec (left side)
+    //       GlobalLimitExec: skip=2, fetch=1
+    //         NestedLoopJoinExec (Right)
+    //           EmptyExec (left side)
+    //           EmptyExec (right side)
+    let schema = create_schema();
+
+    // Build inner join: NestedLoopJoin(Empty, Empty)
+    let inner_left = empty_exec(Arc::clone(&schema));
+    let inner_right = empty_exec(Arc::clone(&schema));
+    let inner_join = nested_loop_join_exec(inner_left, inner_right, JoinType::Right)?;
+
+    // Add inner limit: GlobalLimitExec: skip=2, fetch=1
+    let inner_limit = global_limit_exec(inner_join, 2, Some(1));
+
+    // Build outer join: NestedLoopJoin(Empty, GlobalLimit)
+    let outer_left = empty_exec(Arc::clone(&schema));
+    let outer_join = nested_loop_join_exec(outer_left, inner_limit, JoinType::Left)?;
+
+    // Add outer limit: GlobalLimitExec: skip=1, fetch=1
+    let outer_limit = global_limit_exec(outer_join, 1, Some(1));
+
+    let initial = format_plan(&outer_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=1, fetch=1
+      NestedLoopJoinExec: join_type=Left
+        EmptyExec
+        GlobalLimitExec: skip=2, fetch=1
+          NestedLoopJoinExec: join_type=Right
+            EmptyExec
+            EmptyExec
+    "
+    );
+
+    let after_optimize =
+        LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?;
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    GlobalLimitExec: skip=1, fetch=1
+      NestedLoopJoinExec: join_type=Left
+        EmptyExec
+        GlobalLimitExec: skip=2, fetch=1
+          NestedLoopJoinExec: join_type=Right
+            EmptyExec
+            EmptyExec
+    "
+    );
+
+    Ok(())
+}
+
+#[test]
+fn preserves_skip_before_sort() -> Result<()> {
+    // If there's a limit with skip before a node that (1) supports fetch but
+    // (2) does not support limit pushdown, that limit should not be removed.
+    //
+    // Plan structure:
+    //   GlobalLimitExec: skip=1, fetch=None
+    //     SortExec: TopK(fetch=4)
+    //       EmptyExec
+    let schema = create_schema();
+
+    let empty = empty_exec(Arc::clone(&schema));
+
+    let ordering = [PhysicalSortExpr {
+        expr: col("c1", &schema)?,
+        options: SortOptions::default(),
+    }];
+    let sort = sort_exec(ordering.into(), empty)
+        .with_fetch(Some(4))
+        .unwrap();
+
+    let outer_limit = global_limit_exec(sort, 1, None);
+
+    let initial = format_plan(&outer_limit);
+    insta::assert_snapshot!(
+        initial,
+        @r"
+    GlobalLimitExec: skip=1, fetch=None
+      SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false]
+        EmptyExec
+    "
+    );
+
+    let after_optimize =
+        LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?;
+    let optimized = format_plan(&after_optimize);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    GlobalLimitExec: skip=1, fetch=3
+      SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false]
+        EmptyExec
+    "
+    );
+
+    Ok(())
+}
+
+#[test]
+fn no_limit_preserves_plan_identity() -> Result<()> {
+    // When there is no limit in the plan, the optimizer should return the
+    // exact same Arc (pointer-equal) for every node, avoiding unnecessary
+    // plan reconstruction and property recomputation.
+    let schema = create_schema();
+
+    let left = empty_exec(Arc::clone(&schema));
+    let right = empty_exec(Arc::clone(&schema));
+    let on = join_on_columns("c1", "c1");
+    let join = hash_join_exec(left, right, on, None, &JoinType::Inner)?;
+    let plan = filter_exec(Arc::clone(&schema), join)?;
+
+    let optimized =
+        LimitPushdown::new().optimize(Arc::clone(&plan), &ConfigOptions::new())?;
+
+    assert!(
+        Arc::ptr_eq(&plan, &optimized),
+        "Expected optimizer to return the same Arc when no limit is present"
+    );
+
+    let optimized = format_plan(&optimized);
+    insta::assert_snapshot!(
+        optimized,
+        @r"
+    FilterExec: c3@2 > 0
+      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)]
+        EmptyExec
+        EmptyExec
+    "
+    );
 
     Ok(())
 }
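Editorial aside: for `preserves_skip_before_sort` above, the `fetch=3` in the optimized snapshot is the arithmetic consequence of the child cap: `SortExec: TopK(fetch=4)` emits at most 4 rows, and after skipping 1 the parent can return at most 3. A worked check with those values (illustrative):

// TopK(fetch=4) bounds the input; GlobalLimitExec skips 1 row, so its
// effective fetch is 4 - 1 = 3, exactly what the snapshot records.
let (child_fetch, skip) = (4usize, 1usize);
assert_eq!(child_fetch - skip, 3);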
diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs
index ad15d6803413b..c523b4a752a82 100644
--- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs
+++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs
@@ -21,8 +21,8 @@
 use insta::assert_snapshot;
 use std::sync::Arc;
 
 use crate::physical_optimizer::test_utils::{
-    build_group_by, get_optimized_plan, mock_data, parquet_exec_with_sort, schema,
-    TestAggregate,
+    TestAggregate, build_group_by, get_optimized_plan, mock_data, parquet_exec_with_sort,
+    schema,
 };
 
 use arrow::datatypes::DataType;
@@ -34,10 +34,10 @@
 use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::{self, cast, col};
 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
 use datafusion_physical_plan::{
+    ExecutionPlan,
     aggregates::{AggregateExec, AggregateMode},
     collect,
     limit::{GlobalLimitExec, LocalLimitExec},
-    ExecutionPlan,
 };
 
 async fn run_plan_and_format(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs
index 936c02eb2a02d..b7ba661d2343a 100644
--- a/datafusion/core/tests/physical_optimizer/mod.rs
+++ b/datafusion/core/tests/physical_optimizer/mod.rs
@@ -17,18 +17,26 @@
 
 //! Physical Optimizer integration tests
 
+#[expect(clippy::needless_pass_by_value)]
 mod aggregate_statistics;
 mod combine_partial_final_agg;
+#[expect(clippy::needless_pass_by_value)]
 mod enforce_distribution;
 mod enforce_sorting;
 mod enforce_sorting_monotonicity;
 mod filter_pushdown;
 mod join_selection;
+#[expect(clippy::needless_pass_by_value)]
 mod limit_pushdown;
 mod limited_distinct_aggregation;
 mod partition_statistics;
 mod projection_pushdown;
+mod pushdown_sort;
 mod replace_with_order_preserving_variants;
 mod sanity_checker;
+#[expect(clippy::needless_pass_by_value)]
 mod test_utils;
 mod window_optimize;
+mod window_topn;
+
+mod pushdown_utils;
diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs
index 49dc5b845605d..f84d79146b24d 100644
--- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs
+++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs
@@ -25,36 +25,41 @@ mod test {
     use datafusion::datasource::listing::ListingTable;
     use datafusion::prelude::SessionContext;
     use datafusion_catalog::TableProvider;
-    use datafusion_common::stats::Precision;
     use datafusion_common::Result;
-    use datafusion_common::{ColumnStatistics, ScalarValue, Statistics};
-    use datafusion_execution::config::SessionConfig;
+    use datafusion_common::stats::Precision;
+    use datafusion_common::{
+        ColumnStatistics, JoinType, NullEquality, ScalarValue, Statistics,
+    };
     use datafusion_execution::TaskContext;
+    use datafusion_execution::config::SessionConfig;
+    use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
     use datafusion_expr_common::operator::Operator;
     use datafusion_functions_aggregate::count::count_udaf;
-    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
-    use datafusion_physical_expr::expressions::{binary, col, lit, Column};
     use datafusion_physical_expr::Partitioning;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::expressions::{Column, binary, col, lit};
     use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
     use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
     use datafusion_physical_plan::aggregates::{
         AggregateExec, AggregateMode, PhysicalGroupBy,
     };
-    use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
     use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
    use datafusion_physical_plan::common::compute_record_batch_statistics;
     use datafusion_physical_plan::empty::EmptyExec;
     use datafusion_physical_plan::filter::FilterExec;
-    use datafusion_physical_plan::joins::CrossJoinExec;
+    use datafusion_physical_plan::joins::{
+        CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
+    };
     use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
     use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
     use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
     use datafusion_physical_plan::repartition::RepartitionExec;
     use datafusion_physical_plan::sorts::sort::SortExec;
     use datafusion_physical_plan::union::{InterleaveExec, UnionExec};
+    use datafusion_physical_plan::windows::{WindowAggExec, create_window_expr};
     use datafusion_physical_plan::{
-        execute_stream_partitioned, get_plan_string, ExecutionPlan,
-        ExecutionPlanProperties,
+        ExecutionPlan, ExecutionPlanProperties, execute_stream_partitioned,
+        get_plan_string,
     };
     use futures::TryStreamExt;
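Editorial aside: the partition-statistics tests below hard-code Date32 day counts (e.g. 2025-03-01 = 20148). Date32 stores days since the Unix epoch, so the constants can be derived mechanically; a sketch using the `chrono` crate, which is an assumption for illustration — the test file itself does not depend on it:

// Sketch only: Date32 value = days between 1970-01-01 and the given date.
use chrono::NaiveDate;

let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
let date = NaiveDate::from_ymd_opt(2025, 3, 1).unwrap();
assert_eq!(date.signed_duration_since(epoch).num_days(), 20148);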
@@ -67,7 +72,7 @@ mod test {
     /// - Each partition has an "id" column (INT) with the following values:
     ///   - First partition: [3, 4]
     ///   - Second partition: [1, 2]
-    /// - Each row is 110 bytes in size
+    /// - Each partition has 16 bytes total (Int32 id: 4 bytes × 2 rows + Date32 date: 4 bytes × 2 rows)
     ///
     /// @param create_table_sql Optional parameter to set the create table SQL
     /// @param target_partition Optional parameter to set the target partitions
@@ -101,40 +106,61 @@ mod test {
             .await
             .unwrap();
         let table = ctx.table_provider(table_name.as_str()).await.unwrap();
-        let listing_table = table
-            .as_any()
-            .downcast_ref::<ListingTable>()
-            .unwrap()
-            .clone();
+        let listing_table = table.downcast_ref::<ListingTable>().unwrap().clone();
         listing_table
             .scan(&ctx.state(), None, &[], None)
             .await
             .unwrap()
     }
 
+    // Date32 values for test data (days since 1970-01-01):
+    //   2025-03-01 = 20148
+    //   2025-03-02 = 20149
+    //   2025-03-03 = 20150
+    //   2025-03-04 = 20151
+    const DATE_2025_03_01: i32 = 20148;
+    const DATE_2025_03_02: i32 = 20149;
+    const DATE_2025_03_03: i32 = 20150;
+    const DATE_2025_03_04: i32 = 20151;
+
     /// Helper function to create expected statistics for a partition with Int32 column
+    ///
+    /// If `date_range` is provided, includes exact statistics for the partition date column.
+    /// Partition column statistics are exact because all rows in a partition share the same value.
     fn create_partition_statistics(
         num_rows: usize,
         total_byte_size: usize,
         min_value: i32,
         max_value: i32,
-        include_date_column: bool,
+        date_range: Option<(i32, i32)>,
     ) -> Statistics {
+        // Int32 is 4 bytes per row
+        let int32_byte_size = num_rows * 4;
         let mut column_stats = vec![ColumnStatistics {
             null_count: Precision::Exact(0),
             max_value: Precision::Exact(ScalarValue::Int32(Some(max_value))),
             min_value: Precision::Exact(ScalarValue::Int32(Some(min_value))),
             sum_value: Precision::Absent,
             distinct_count: Precision::Absent,
+            byte_size: Precision::Exact(int32_byte_size),
         }];
 
-        if include_date_column {
+        if let Some((min_date, max_date)) = date_range {
+            // Partition column stats are computed from partition values:
+            // - null_count = 0 (partition values from paths are never null)
+            // - min/max are the merged partition values across files in the group
+            // - byte_size = num_rows * 4 (Date32 is 4 bytes per row)
+            // - distinct_count = Inexact(max_date - min_date + 1), derived from the
+            //   date range via interval analysis for temporal types
+            let date32_byte_size = num_rows * 4;
+            let distinct_dates = (max_date - min_date + 1) as usize;
             column_stats.push(ColumnStatistics {
-                null_count: Precision::Absent,
-                max_value: Precision::Absent,
-                min_value: Precision::Absent,
+                null_count: Precision::Exact(0),
+                max_value: Precision::Exact(ScalarValue::Date32(Some(max_date))),
+                min_value: Precision::Exact(ScalarValue::Date32(Some(min_date))),
                 sum_value: Precision::Absent,
-                distinct_count: Precision::Absent,
+                distinct_count: Precision::Inexact(distinct_dates),
+                byte_size: Precision::Exact(date32_byte_size),
             });
         }
@@ -214,14 +240,26 @@ mod test {
         let statistics = (0..scan.output_partitioning().partition_count())
             .map(|idx| scan.partition_statistics(Some(idx)))
            .collect::<Result<Vec<_>>>()?;
-        let expected_statistic_partition_1 =
-            create_partition_statistics(2, 110, 3, 4, true);
-        let expected_statistic_partition_2 =
-            create_partition_statistics(2, 110, 1, 2, true);
+        // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02]
+        let expected_statistic_partition_1 = create_partition_statistics(
+            2,
+            16,
+            3,
+            4,
+            Some((DATE_2025_03_01, DATE_2025_03_02)),
+        );
+        // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04]
+ let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); // Check the statistics of each partition assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -246,14 +284,15 @@ mod test { let statistics = (0..projection.output_partitioning().partition_count()) .map(|idx| projection.partition_statistics(Some(idx))) .collect::>>()?; + // Projection only includes id column, not the date partition column let expected_statistic_partition_1 = - create_partition_statistics(2, 8, 3, 4, false); + create_partition_statistics(2, 8, 3, 4, None); let expected_statistic_partition_2 = - create_partition_statistics(2, 8, 1, 2, false); + create_partition_statistics(2, 8, 1, 2, None); // Check the statistics of each partition assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -277,10 +316,16 @@ mod test { let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; - let expected_statistic_partition = - create_partition_statistics(4, 220, 1, 4, true); + // All 4 files merged: ids [1-4], dates [2025-03-01, 2025-03-04] + let expected_statistic_partition = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); assert_eq!(statistics.len(), 1); - assert_eq!(statistics[0], expected_statistic_partition); + assert_eq!(*statistics[0], expected_statistic_partition); // Check the statistics_by_partition with real results let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; validate_statistics_with_data(sort_exec.clone(), expected_stats, 0).await?; @@ -291,16 +336,28 @@ mod test { let sort_exec: Arc = Arc::new( SortExec::new(ordering.into(), scan_2).with_preserve_partitioning(true), ); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -324,34 +381,61 @@ mod test { let filter: Arc = 
Arc::new(FilterExec::try_new(predicate, scan)?); let full_statistics = filter.partition_statistics(None)?; + // Filter preserves original total_rows and byte_size from input + // (4 total rows = 2 partitions * 2 rows each, byte_size = 4 * 4 = 16 bytes for int32) let expected_full_statistic = Statistics { num_rows: Precision::Inexact(0), total_byte_size: Precision::Inexact(0), column_statistics: vec![ ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Int32(None)), + min_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(16), }, ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Date32(None)), + min_value: Precision::Exact(ScalarValue::Date32(None)), + sum_value: Precision::Exact(ScalarValue::Date32(None)), distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(16), // 4 rows * 4 bytes (Date32) }, ], }; - assert_eq!(full_statistics, expected_full_statistic); + assert_eq!(*full_statistics, expected_full_statistic); let statistics = (0..filter.output_partitioning().partition_count()) .map(|idx| filter.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_full_statistic); - assert_eq!(statistics[1], expected_full_statistic); + // Per-partition stats: each partition has 2 rows, byte_size = 2 * 4 = 8 + let expected_partition_statistic = Statistics { + num_rows: Precision::Inexact(0), + total_byte_size: Precision::Inexact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(None)), + min_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), + distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(None)), + min_value: Precision::Exact(ScalarValue::Date32(None)), + sum_value: Precision::Exact(ScalarValue::Date32(None)), + distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(8), // 2 rows * 4 bytes (Date32) + }, + ], + }; + assert_eq!(*statistics[0], expected_partition_statistic); + assert_eq!(*statistics[1], expected_partition_statistic); Ok(()) } @@ -365,18 +449,30 @@ mod test { .collect::>>()?; // Check that we have 4 partitions (2 from each scan) assert_eq!(statistics.len(), 4); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); // Verify first partition (from first scan) - assert_eq!(statistics[0], 
expected_statistic_partition_1); + assert_eq!(*statistics[0], expected_statistic_partition_1); // Verify second partition (from first scan) - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Verify third partition (from second scan - same as first partition) - assert_eq!(statistics[2], expected_statistic_partition_1); + assert_eq!(*statistics[2], expected_statistic_partition_1); // Verify fourth partition (from second scan - same as second partition) - assert_eq!(statistics[3], expected_statistic_partition_2); + assert_eq!(*statistics[3], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -416,16 +512,17 @@ mod test { .collect::>>()?; assert_eq!(stats.len(), 2); + // Each partition gets half of combined input, total_rows per partition = 4 let expected_stats = Statistics { num_rows: Precision::Inexact(4), - total_byte_size: Precision::Inexact(220), + total_byte_size: Precision::Inexact(32), column_statistics: vec![ ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], }; - assert_eq!(stats[0], expected_stats); - assert_eq!(stats[1], expected_stats); + assert_eq!(*stats[0], expected_stats); + assert_eq!(*stats[1], expected_stats); // Verify the execution results let partitions = execute_stream_partitioned( @@ -461,30 +558,78 @@ mod test { .collect::>>()?; // Check that we have 2 partitions assert_eq!(statistics.len(), 2); - let mut expected_statistic_partition_1 = - create_partition_statistics(8, 48400, 1, 4, true); - expected_statistic_partition_1 - .column_statistics - .push(ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Int32(Some(4))), - min_value: Precision::Exact(ScalarValue::Int32(Some(3))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }); - let mut expected_statistic_partition_2 = - create_partition_statistics(8, 48400, 1, 4, true); - expected_statistic_partition_2 - .column_statistics - .push(ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Int32(Some(2))), - min_value: Precision::Exact(ScalarValue::Int32(Some(1))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + // Cross join output schema: [left.id, left.date, right.id] + // Cross join doesn't propagate Column's byte_size + let expected_statistic_partition_1 = Statistics { + num_rows: Precision::Exact(8), + total_byte_size: Precision::Exact(512), + column_statistics: vec![ + // column 0: left.id (Int32, file column from t1) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 1: left.date (Date32, partition column from t1) + // Partition column statistics are exact because all rows in a partition share the same value. 
+ ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), + min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(4), + byte_size: Precision::Absent, + }, + // column 2: right.id (Int32, file column from t2) - right partition 0: ids [3,4] + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ], + }; + let expected_statistic_partition_2 = Statistics { + num_rows: Precision::Exact(8), + total_byte_size: Precision::Exact(512), + column_statistics: vec![ + // column 0: left.id (Int32, file column from t1) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 1: left.date (Date32, partition column from t1) + // Partition column statistics are exact because all rows in a partition share the same value. + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), + min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(4), + byte_size: Precision::Absent, + }, + // column 2: right.id (Int32, file column from t2) - right partition 1: ids [1,2] + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ], + }; + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -496,27 +641,77 @@ mod test { } #[tokio::test] - async fn test_statistic_by_partition_of_coalesce_batches() -> Result<()> { - let scan = create_scan_exec_with_statistics(None, Some(2)).await; - let coalesce_batches: Arc = - Arc::new(CoalesceBatchesExec::new(scan, 2)); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); - let statistics = (0..coalesce_batches.output_partitioning().partition_count()) - .map(|idx| coalesce_batches.partition_statistics(Some(idx))) + async fn test_statistic_by_partition_of_nested_loop_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left_scan = create_scan_exec_with_statistics(None, Some(2)).await; + let left_scan_coalesced: Arc = + Arc::new(CoalescePartitionsExec::new(left_scan)); + + let right_scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let nested_loop_join: Arc = + Arc::new(NestedLoopJoinExec::try_new( + left_scan_coalesced, + right_scan, + None, + &JoinType::RightSemi, + None, + )?); + + // Test partition_statistics(None) - returns overall statistics + // For RightSemi join, output columns come from right side only + let full_statistics = nested_loop_join.partition_statistics(None)?; + // With empty join columns, estimate_join_statistics 
returns Inexact row count + // based on the outer side (right side for RightSemi) + let mut expected_full_statistics = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); + expected_full_statistics.num_rows = Precision::Inexact(4); + expected_full_statistics.total_byte_size = Precision::Absent; + assert_eq!(*full_statistics, expected_full_statistics); + + // Test partition_statistics(Some(idx)) - returns partition-specific statistics + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let mut expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + expected_statistic_partition_1.num_rows = Precision::Inexact(2); + expected_statistic_partition_1.total_byte_size = Precision::Absent; + + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let mut expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); + expected_statistic_partition_2.num_rows = Precision::Inexact(2); + expected_statistic_partition_2.total_byte_size = Precision::Absent; + + let statistics = (0..nested_loop_join.output_partitioning().partition_count()) + .map(|idx| nested_loop_join.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ ExpectedStatistics::NonEmpty(3, 4, 2), ExpectedStatistics::NonEmpty(1, 2, 2), ]; - validate_statistics_with_data(coalesce_batches, expected_stats, 0).await?; + validate_statistics_with_data(nested_loop_join, expected_stats, 0).await?; + Ok(()) } @@ -525,13 +720,19 @@ mod test { let scan = create_scan_exec_with_statistics(None, Some(2)).await; let coalesce_partitions: Arc = Arc::new(CoalescePartitionsExec::new(scan)); - let expected_statistic_partition = - create_partition_statistics(4, 220, 1, 4, true); + // All files merged: ids [1-4], dates [2025-03-01, 2025-03-04] + let expected_statistic_partition = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); let statistics = (0..coalesce_partitions.output_partitioning().partition_count()) .map(|idx| coalesce_partitions.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 1); - assert_eq!(statistics[0], expected_statistic_partition); + assert_eq!(*statistics[0], expected_statistic_partition); // Check the statistics_by_partition with real results let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; @@ -548,20 +749,20 @@ mod test { .map(|idx| local_limit.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - let mut expected_0 = statistics[0].clone(); + let mut expected_0 = Statistics::clone(&statistics[0]); expected_0.column_statistics = expected_0 .column_statistics .into_iter() .map(|c| c.to_inexact()) .collect(); - let mut expected_1 = statistics[1].clone(); + let mut expected_1 = Statistics::clone(&statistics[1]); expected_1.column_statistics = expected_1 .column_statistics .into_iter() .map(|c| c.to_inexact()) .collect(); - assert_eq!(statistics[0], expected_0); - assert_eq!(statistics[1], expected_1); + assert_eq!(*statistics[0], expected_0); + 
assert_eq!(*statistics[1], expected_1); Ok(()) } @@ -575,9 +776,15 @@ mod test { .map(|idx| global_limit.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 1); - let expected_statistic_partition = - create_partition_statistics(2, 110, 3, 4, true); - assert_eq!(statistics[0], expected_statistic_partition); + // GlobalLimit takes from first partition: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + assert_eq!(*statistics[0], expected_statistic_partition); Ok(()) } @@ -601,11 +808,13 @@ mod test { ), ]); - let aggr_expr = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) - .schema(Arc::clone(&scan_schema)) - .alias(String::from("COUNT(c)")) - .build() - .map(Arc::new)?]; + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&scan_schema)) + .alias(String::from("COUNT(c)")) + .build() + .map(Arc::new)?, + ]; let aggregate_exec_partial: Arc = Arc::new(AggregateExec::try_new( @@ -625,9 +834,10 @@ mod test { let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; + // Aggregate doesn't propagate num_rows and ColumnStatistics byte_size from input let expected_p0_statistics = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Absent, + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics { null_count: Precision::Absent, @@ -635,17 +845,18 @@ mod test { min_value: Precision::Exact(ScalarValue::Int32(Some(3))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], }; - assert_eq!(&p0_statistics, &expected_p0_statistics); + assert_eq!(*p0_statistics, expected_p0_statistics); let expected_p1_statistics = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Absent, + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics { null_count: Precision::Absent, @@ -653,6 +864,7 @@ mod test { min_value: Precision::Exact(ScalarValue::Int32(Some(1))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -660,7 +872,7 @@ mod test { }; let p1_statistics = aggregate_exec_partial.partition_statistics(Some(1))?; - assert_eq!(&p1_statistics, &expected_p1_statistics); + assert_eq!(*p1_statistics, expected_p1_statistics); validate_statistics_with_data( aggregate_exec_partial.clone(), @@ -682,10 +894,10 @@ mod test { )?); let p0_statistics = agg_final.partition_statistics(Some(0))?; - assert_eq!(&p0_statistics, &expected_p0_statistics); + assert_eq!(*p0_statistics, expected_p0_statistics); let p1_statistics = agg_final.partition_statistics(Some(1))?; - assert_eq!(&p1_statistics, &expected_p1_statistics); + assert_eq!(*p1_statistics, expected_p1_statistics); validate_statistics_with_data( agg_final.clone(), @@ -720,14 +932,17 @@ mod test { num_rows: Precision::Exact(0), total_byte_size: Precision::Absent, column_statistics: vec![ - ColumnStatistics::new_unknown(), + ColumnStatistics { + distinct_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], }; - assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(0))?); - assert_eq!(&empty_stat, 
&agg_partial.partition_statistics(Some(1))?); + assert_eq!(empty_stat, *agg_partial.partition_statistics(Some(0))?); + assert_eq!(empty_stat, *agg_partial.partition_statistics(Some(1))?); validate_statistics_with_data( agg_partial.clone(), vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], @@ -753,8 +968,8 @@ mod test { agg_partial.schema(), )?); - assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(0))?); - assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(1))?); + assert_eq!(empty_stat, *agg_final.partition_statistics(Some(0))?); + assert_eq!(empty_stat, *agg_final.partition_statistics(Some(1))?); validate_statistics_with_data( agg_final, @@ -790,7 +1005,7 @@ mod test { column_statistics: vec![ColumnStatistics::new_unknown()], }; - assert_eq!(&expect_stat, &agg_final.partition_statistics(Some(0))?); + assert_eq!(expect_stat, *agg_final.partition_statistics(Some(0))?); // Verify that the aggregate final result has exactly one partition with one row let mut partitions = execute_stream_partitioned( @@ -824,13 +1039,13 @@ mod test { &schema, None, ); - assert_eq!(actual, expected); + assert_eq!(*actual, expected); all_batches.push(batches); } let actual = plan.partition_statistics(None)?; let expected = compute_record_batch_statistics(&all_batches, &schema, None); - assert_eq!(actual, expected); + assert_eq!(*actual, expected); Ok(()) } @@ -849,9 +1064,10 @@ mod test { .collect::>>()?; assert_eq!(statistics.len(), 3); + // Repartition preserves original total_rows from input (4 rows total) let expected_stats = Statistics { num_rows: Precision::Inexact(1), - total_byte_size: Precision::Inexact(73), + total_byte_size: Precision::Inexact(10), column_statistics: vec![ ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -860,7 +1076,7 @@ mod test { // All partitions should have the same statistics for stat in statistics.iter() { - assert_eq!(stat, &expected_stats); + assert_eq!(**stat, expected_stats); } // Verify that the result has exactly 3 partitions @@ -878,9 +1094,9 @@ mod test { partition_row_counts.push(total_rows); } assert_eq!(partition_row_counts.len(), 3); - assert_eq!(partition_row_counts[0], 2); + assert_eq!(partition_row_counts[0], 1); assert_eq!(partition_row_counts[1], 2); - assert_eq!(partition_row_counts[2], 0); + assert_eq!(partition_row_counts[2], 1); Ok(()) } @@ -898,9 +1114,11 @@ mod test { let result = repartition.partition_statistics(Some(2)); assert!(result.is_err()); let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("RepartitionExec invalid partition 2 (expected less than 2)")); + assert!( + error + .to_string() + .contains("RepartitionExec invalid partition 2 (expected less than 2)") + ); let partitions = execute_stream_partitioned( repartition.clone(), @@ -923,7 +1141,7 @@ mod test { )?); let result = repartition.partition_statistics(Some(0))?; - assert_eq!(result, Statistics::new_unknown(&scan_schema)); + assert_eq!(*result, Statistics::new_unknown(&scan_schema)); // Verify that the result has exactly 0 partitions let partitions = execute_stream_partitioned( @@ -953,16 +1171,17 @@ mod test { .collect::>>()?; assert_eq!(stats.len(), 2); + // Repartition preserves original total_rows from input (4 rows total) let expected_stats = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Inexact(110), + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], }; - assert_eq!(stats[0], expected_stats); - 
assert_eq!(stats[1], expected_stats); + assert_eq!(*stats[0], expected_stats); + assert_eq!(*stats[1], expected_stats); // Verify the repartition execution results let partitions = @@ -980,4 +1199,412 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_statistic_by_partition_of_window_agg() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let window_expr = create_window_expr( + &WindowFunctionDefinition::AggregateUDF(count_udaf()), + "count".to_owned(), + &[col("id", &scan.schema())?], + &[], // no partition by + &[PhysicalSortExpr::new( + col("id", &scan.schema())?, + SortOptions::default(), + )], + Arc::new(WindowFrame::new(Some(false))), + scan.schema(), + false, + false, + None, + )?; + + let window_agg: Arc = + Arc::new(WindowAggExec::try_new(vec![window_expr], scan, true)?); + + // Verify partition statistics are properly propagated (not unknown) + let statistics = (0..window_agg.output_partitioning().partition_count()) + .map(|idx| window_agg.partition_statistics(Some(idx))) + .collect::>>()?; + + assert_eq!(statistics.len(), 2); + + // Window functions preserve input row counts and column statistics + // but add unknown statistics for the new window column + let expected_statistic_partition_1 = Statistics { + num_rows: Precision::Exact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(2), + byte_size: Precision::Exact(8), + }, + ColumnStatistics::new_unknown(), // window column + ], + }; + + let expected_statistic_partition_2 = Statistics { + num_rows: Precision::Exact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_03, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(2), + byte_size: Precision::Exact(8), + }, + ColumnStatistics::new_unknown(), // window column + ], + }; + + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); + + // Verify the statistics match actual execution results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(window_agg, expected_stats, 0).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_empty_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Try to test with single partition + let 
empty_single = Arc::new(EmptyExec::new(Arc::clone(&schema)));
+
+        let stats = empty_single.partition_statistics(Some(0))?;
+        assert_eq!(stats.num_rows, Precision::Exact(0));
+        assert_eq!(stats.total_byte_size, Precision::Exact(0));
+        assert_eq!(stats.column_statistics.len(), 2);
+
+        for col_stat in &stats.column_statistics {
+            assert_eq!(col_stat.null_count, Precision::Exact(0));
+            assert_eq!(col_stat.distinct_count, Precision::Exact(0));
+            assert_eq!(col_stat.byte_size, Precision::Exact(0));
+            assert_eq!(col_stat.min_value, Precision::<ScalarValue>::Absent);
+            assert_eq!(col_stat.max_value, Precision::<ScalarValue>::Absent);
+            assert_eq!(col_stat.sum_value, Precision::<ScalarValue>::Absent);
+        }
+
+        let overall_stats = empty_single.partition_statistics(None)?;
+        assert_eq!(stats, overall_stats);
+
+        validate_statistics_with_data(empty_single, vec![ExpectedStatistics::Empty], 0)
+            .await?;
+
+        // Test with multiple partitions
+        let empty_multi: Arc<dyn ExecutionPlan> =
+            Arc::new(EmptyExec::new(Arc::clone(&schema)).with_partitions(3));
+
+        let statistics = (0..empty_multi.output_partitioning().partition_count())
+            .map(|idx| empty_multi.partition_statistics(Some(idx)))
+            .collect::<Result<Vec<_>>>()?;
+
+        assert_eq!(statistics.len(), 3);
+
+        for stat in &statistics {
+            assert_eq!(stat.num_rows, Precision::Exact(0));
+            assert_eq!(stat.total_byte_size, Precision::Exact(0));
+        }
+
+        validate_statistics_with_data(
+            empty_multi,
+            vec![
+                ExpectedStatistics::Empty,
+                ExpectedStatistics::Empty,
+                ExpectedStatistics::Empty,
+            ],
+            0,
+        )
+        .await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_hash_join_partition_statistics() -> Result<()> {
+        // Create left table scan and coalesce to 1 partition for CollectLeft mode
+        let left_scan = create_scan_exec_with_statistics(None, Some(2)).await;
+        let left_scan_coalesced = Arc::new(CoalescePartitionsExec::new(left_scan.clone()))
+            as Arc<dyn ExecutionPlan>;
+
+        // Create right table scan with a different table name
+        let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL, date DATE) \
+         STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\
+         PARTITIONED BY (date) \
+         WITH ORDER (id ASC);";
+        let right_scan =
+            create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await;
+
+        // Create join condition: t1.id = t2.id
+        let on = vec![(
+            Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>,
+            Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>,
+        )];
+
+        // Test CollectLeft mode - the left child must have exactly 1 partition
+        let collect_left_join = Arc::new(HashJoinExec::try_new(
+            left_scan_coalesced,
+            Arc::clone(&right_scan),
+            on.clone(),
+            None,
+            &JoinType::Inner,
+            None,
+            PartitionMode::CollectLeft,
+            NullEquality::NullEqualsNothing,
+            false,
+        )?) as Arc<dyn ExecutionPlan>;
+
+        // Test partition statistics for CollectLeft mode
+        let statistics = (0..collect_left_join.output_partitioning().partition_count())
+            .map(|idx| collect_left_join.partition_statistics(Some(idx)))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Check that we have the expected number of partitions
+        assert_eq!(statistics.len(), 2);
+
+        // For CollectLeft mode, the min/max values come from the entire left table
+        // and from the specific partition of the right table.
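+        // A minimal sketch of that merge (illustrative only; the fixture
+        // guarantees partition 0 holds ids [3, 4] and partition 1 holds
+        // ids [1, 2], so the coalesced left side spans both):
+        let (left_min, left_max) = (i32::min(3, 1), i32::max(4, 2));
+        debug_assert_eq!((left_min, left_max), (1, 4));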
+ let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // Left id column: all partitions (id 1..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(16), + }, + // Left date column: all partitions (2025-03-01..2025-03-04) + // NDV is Inexact(4) derived from the date range via interval analysis + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(4), + byte_size: Precision::Exact(16), + }, + // Right id column: partition 0 only (id 3..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + // Right date column: partition 0 only (2025-03-01..2025-03-02) + // NDV is Inexact(2) derived from the date range via interval analysis + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(2), + byte_size: Precision::Exact(8), + }, + ], + }; + assert_eq!(*statistics[0], expected_p0_statistics); + + // Test Partitioned mode + let partitioned_join = Arc::new(HashJoinExec::try_new( + Arc::clone(&left_scan), + Arc::clone(&right_scan), + on.clone(), + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Test partition statistics for Partitioned mode + let statistics = (0..partitioned_join.output_partitioning().partition_count()) + .map(|idx| partitioned_join.partition_statistics(Some(idx))) + .collect::>>()?; + + // Check that we have the expected number of partitions + assert_eq!(statistics.len(), 2); + + // For partitioned mode, the min/max values are from the specific partition for each side. 
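+        // Illustrative arithmetic for the expectation below: partition 0 of
+        // each side holds 2 of the 4 total rows, so the paired inputs yield
+        // an Inexact(2) row estimate for this output partition:
+        let rows_per_partition = 4usize / 2;
+        debug_assert_eq!(rows_per_partition, 2);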
+ let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // Left id column: partition 0 only (id 3..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + // Left date column: partition 0 only (2025-03-01..2025-03-02) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(2), + byte_size: Precision::Exact(8), + }, + // Right id column: partition 0 only (id 3..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + // Right date column: partition 0 only (2025-03-01..2025-03-02) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(2), + byte_size: Precision::Exact(8), + }, + ], + }; + assert_eq!(*statistics[0], expected_p0_statistics); + + // Test Auto mode - should fall back to getting all partition statistics + let auto_join = Arc::new(HashJoinExec::try_new( + Arc::clone(&left_scan), + Arc::clone(&right_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::Auto, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Test partition statistics for Auto mode + let statistics = (0..auto_join.output_partitioning().partition_count()) + .map(|idx| auto_join.partition_statistics(Some(idx))) + .collect::>>()?; + + // Check that we have the expected number of partitions + assert_eq!(statistics.len(), 2); + + // For auto mode, the min/max values are from the entire left and right tables. 
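+        // Sketch of the interval-based NDV estimate expected below: for a
+        // Date32 column the distinct count is derived from the closed day
+        // range, so the full 2025-03-01..=2025-03-04 span yields Inexact(4):
+        let ndv_estimate = (DATE_2025_03_04 - DATE_2025_03_01 + 1) as usize;
+        debug_assert_eq!(ndv_estimate, 4);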
+ let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(4), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // Left id column: all partitions (id 1..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(16), + }, + // Left date column: all partitions (2025-03-01..2025-03-04) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(4), + byte_size: Precision::Exact(16), + }, + // Right id column: all partitions (id 1..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(16), + }, + // Right date column: all partitions (2025-03-01..2025-03-04) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(4), + byte_size: Precision::Exact(16), + }, + ], + }; + assert_eq!(*statistics[0], expected_p0_statistics); + Ok(()) + } } diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 8631613c3925e..6635220cf2028 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::any::Any; use std::sync::Arc; use arrow::compute::SortOptions; @@ -24,8 +23,9 @@ use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; +use datafusion_datasource::TableSchema; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -34,30 +34,31 @@ use datafusion_expr::{ }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr::expressions::{ - binary, cast, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col, }; use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::coop::CooperativeExec; +use datafusion_physical_plan::filter::{FilterExec, FilterExecBuilder}; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::projection::{update_expr, ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr, update_expr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::{displayable, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, displayable}; use insta::assert_snapshot; use itertools::Itertools; @@ -77,10 +78,6 @@ impl DummyUDF { } impl ScalarUDFImpl for DummyUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "dummy_udf" } @@ -229,9 +226,11 @@ fn test_update_matching_exprs() -> Result<()> { .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &child_exprs, true)? - .unwrap() - .eq(&expected_expr)); + assert!( + update_expr(&expr, &child_exprs, true)? 
+ .unwrap() + .eq(&expected_expr) + ); } Ok(()) @@ -368,9 +367,11 @@ fn test_update_projected_exprs() -> Result<()> { .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &proj_exprs, false)? - .unwrap() - .eq(&expected_expr)); + assert!( + update_expr(&expr, &proj_exprs, false)? + .unwrap() + .eq(&expected_expr) + ); } Ok(()) @@ -384,14 +385,20 @@ fn create_simple_csv_exec() -> Arc { Field::new("d", DataType::Int32, true), Field::new("e", DataType::Int32, true), ])); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(CsvSource::new(false, 0, 0)), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_projection_indices(Some(vec![0, 1, 2, 3, 4])) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![0, 1, 2, 3, 4])) + .unwrap() + .build(); DataSourceExec::from_data_source(config) } @@ -403,14 +410,20 @@ fn create_projecting_csv_exec() -> Arc { Field::new("c", DataType::Int32, true), Field::new("d", DataType::Int32, true), ])); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(CsvSource::new(false, 0, 0)), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_projection_indices(Some(vec![3, 2, 1])) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![3, 2, 1])) + .unwrap() + .build(); DataSourceExec::from_data_source(config) } @@ -432,8 +445,8 @@ fn test_csv_after_projection() -> Result<()> { let csv = create_projecting_csv_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("b", 2)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 0)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 2)), "b"), + ProjectionExpr::new(Arc::new(Column::new("d", 0)), "d"), ], csv.clone(), )?); @@ -469,9 +482,9 @@ fn test_memory_after_projection() -> Result<()> { let memory = create_projecting_memory_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("d", 2)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 3)), "e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 1)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 2)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 3)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 1)), "a"), ], memory.clone(), )?); @@ -502,11 +515,9 @@ fn test_memory_after_projection() -> Result<()> { assert_eq!( after_optimize .clone() - .as_any() .downcast_ref::() .unwrap() .data_source() - .as_any() .downcast_ref::() .unwrap() .projection() @@ -575,9 +586,9 @@ fn test_streaming_table_after_projection() -> Result<()> { )?; let projection = 
Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 2)), "e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 2)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), ], Arc::new(streaming_table) as _, )?) as _; @@ -585,10 +596,7 @@ fn test_streaming_table_after_projection() -> Result<()> { let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let result = after_optimize - .as_any() - .downcast_ref::() - .unwrap(); + let result = after_optimize.downcast_ref::().unwrap(); assert_eq!( result.partition_schema(), &Arc::new(Schema::new(vec![ @@ -642,28 +650,25 @@ fn test_projection_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let child_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 4)), "new_e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "new_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "new_e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "new_b"), ], csv.clone(), )?); let top_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "new_b"), ProjectionExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("c", 0)), Operator::Plus, Arc::new(Column::new("new_e", 1)), )), - "binary".to_string(), - ), - ProjectionExpr::new( - Arc::new(Column::new("new_b", 3)), - "newest_b".to_string(), + "binary", ), + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "newest_b"), ], child_projection.clone(), )?); @@ -692,10 +697,7 @@ fn test_projection_after_projection() -> Result<()> { assert_snapshot!( actual, - @r" - ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - " + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b], file_type=csv, has_header=false" ); Ok(()) @@ -731,9 +733,9 @@ fn test_output_req_after_projection() -> Result<()> { )); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], sort_req.clone(), )?); @@ -762,8 +764,7 @@ fn test_output_req_after_projection() -> Result<()> { actual, @r" OutputRequirementExec: order_by=[(b@2, asc), (c@0 + new_a@1, asc)], dist_by=HashPartitioned[[new_a@1, b@2]]) - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -786,7 +787,6 @@ fn test_output_req_after_projection() -> Result<()> { ); assert_eq!( after_optimize - .as_any() .downcast_ref::() .unwrap() .required_input_ordering()[0] @@ -799,16 +799,16 @@ fn test_output_req_after_projection() -> Result<()> { Arc::new(Column::new("b", 2)), ]; if let Distribution::HashPartitioned(vec) = after_optimize - .as_any() .downcast_ref::() .unwrap() .required_input_distribution()[0] .clone() { - assert!(vec - .iter() - .zip(expected_distribution) - .all(|(actual, expected)| actual.eq(&expected))); + assert!( + vec.iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected)) + ); } else { panic!("Expected HashPartitioned distribution!"); }; @@ -823,9 +823,9 @@ fn test_coalesce_partitions_after_projection() -> Result<()> { Arc::new(CoalescePartitionsExec::new(csv)); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), ], coalesce_partitions, )?); @@ -853,8 +853,7 @@ fn test_coalesce_partitions_after_projection() -> Result<()> { actual, @r" CoalescePartitionsExec - ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, a@0 as a_new, d], file_type=csv, has_header=false " ); @@ -880,9 +879,9 @@ fn test_filter_after_projection() -> Result<()> { let filter = Arc::new(FilterExec::try_new(predicate, csv)?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), ], filter.clone(), )?) 
as _; @@ -911,8 +910,7 @@ fn test_filter_after_projection() -> Result<()> { actual, @r" FilterExec: b@1 - a_new@0 > d@2 - a_new@0 - ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as a_new, b, d], file_type=csv, has_header=false " ); @@ -975,17 +973,11 @@ fn test_join_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - ProjectionExpr::new( - Arc::new(Column::new("a", 5)), - "a_from_right".to_string(), - ), - ProjectionExpr::new( - Arc::new(Column::new("c", 7)), - "c_from_right".to_string(), - ), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a_from_right"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c_from_right"), ], join, )?) as _; @@ -1014,10 +1006,8 @@ fn test_join_after_projection() -> Result<()> { actual, @r" SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 - ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as a_from_right, c@2 as c_from_right], file_type=csv, has_header=false " ); @@ -1039,7 +1029,6 @@ fn test_join_after_projection() -> Result<()> { assert_eq!( expected_filter_col_ind, after_optimize - .as_any() .downcast_ref::() .unwrap() .filter() @@ -1106,16 +1095,16 @@ fn test_join_after_required_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 6)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 8)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 9)), "e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 4)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 6)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c"), + ProjectionExpr::new(Arc::new(Column::new("d", 8)), "d"), + 
ProjectionExpr::new(Arc::new(Column::new("e", 9)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "e"), ], join, )?) as _; @@ -1195,7 +1184,7 @@ fn test_nested_loop_join_after_projection() -> Result<()> { )?) as _; let projection: Arc = Arc::new(ProjectionExec::try_new( - vec![ProjectionExpr::new(col_left_c, "c".to_string())], + vec![ProjectionExpr::new(col_left_c, "c")], Arc::clone(&join), )?) as _; let initial = displayable(projection.as_ref()).indent(true).to_string(); @@ -1282,16 +1271,14 @@ fn test_hash_join_after_projection() -> Result<()> { None, PartitionMode::Auto, NullEquality::NullEqualsNothing, + false, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - ProjectionExpr::new( - Arc::new(Column::new("c", 7)), - "c_from_right".to_string(), - ), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c_from_right"), ], join.clone(), )?) as _; @@ -1318,8 +1305,8 @@ fn test_hash_join_after_projection() -> Result<()> { assert_snapshot!( actual, @r" - ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right] - HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7] + ProjectionExec: expr=[c@0 as c_from_left, b@1 as b_from_left, a@2 as a_from_left, c@3 as c_from_right] + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[c@2, b@1, a@0, c@7] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false " @@ -1327,10 +1314,10 @@ fn test_hash_join_after_projection() -> Result<()> { let projection = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c"), ], join.clone(), )?); @@ -1371,9 +1358,9 @@ fn test_repartition_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_new".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), 
"b_new"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d_new"), ], repartition, )?) as _; @@ -1399,14 +1386,12 @@ fn test_repartition_after_projection() -> Result<()> { actual, @r" RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1 - ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[b@1 as b_new, a, d@3 as d_new], file_type=csv, has_header=false " ); assert_eq!( after_optimize - .as_any() .downcast_ref::() .unwrap() .partitioning() @@ -1441,9 +1426,9 @@ fn test_sort_after_projection() -> Result<()> { ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], Arc::new(sort_exec), )?) as _; @@ -1470,8 +1455,7 @@ fn test_sort_after_projection() -> Result<()> { actual, @r" SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false] - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -1495,9 +1479,9 @@ fn test_sort_preserving_after_projection() -> Result<()> { ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], Arc::new(sort_exec), )?) as _; @@ -1524,8 +1508,7 @@ fn test_sort_preserving_after_projection() -> Result<()> { actual, @r" SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC] - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -1538,9 +1521,9 @@ fn test_union_after_projection() -> Result<()> { let union = UnionExec::try_new(vec![csv.clone(), csv.clone(), csv])?; let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], union.clone(), )?) 
as _; @@ -1569,12 +1552,9 @@ fn test_union_after_projection() -> Result<()> { actual, @r" UnionExec - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -1589,14 +1569,23 @@ fn partitioned_data_source() -> Arc { Field::new("string_col", DataType::Utf8, true), ])); + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::new( + Arc::clone(&file_schema), + vec![Arc::new(Field::new("partition_col", DataType::Utf8, true))], + ); let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - file_schema.clone(), - Arc::new(CsvSource::default()), + Arc::new(CsvSource::new(table_schema).with_csv_options(options)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_table_partition_cols(vec![Field::new("partition_col", DataType::Utf8, true)]) + .with_file(PartitionedFile::new("x", 100)) .with_projection_indices(Some(vec![0, 1, 2])) + .unwrap() .build(); DataSourceExec::from_data_source(config) @@ -1611,16 +1600,13 @@ fn test_partition_col_projection_pushdown() -> Result<()> { vec![ ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, - "string_col".to_string(), + "string_col", ), ProjectionExpr::new( col("partition_col", partitioned_schema.as_ref())?, - "partition_col".to_string(), - ), - ProjectionExpr::new( - col("int_col", partitioned_schema.as_ref())?, - "int_col".to_string(), + "partition_col", ), + ProjectionExpr::new(col("int_col", partitioned_schema.as_ref())?, "int_col"), ], source, )?); @@ -1634,10 +1620,7 @@ fn test_partition_col_projection_pushdown() -> Result<()> { let actual = after_optimize_string.trim(); assert_snapshot!( actual, - @r" - ProjectionExec: expr=[string_col@1 as string_col, partition_col@2 as partition_col, int_col@0 as int_col] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false - " + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[string_col, partition_col, int_col], file_type=csv, has_header=false" ); Ok(()) @@ -1652,7 +1635,7 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { vec![ ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, - "string_col".to_string(), + "string_col", ), ProjectionExpr::new( // CAST(partition_col, Utf8View) @@ -1661,12 +1644,9 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { partitioned_schema.as_ref(), DataType::Utf8View, )?, - "partition_col".to_string(), - ), - ProjectionExpr::new( - col("int_col", partitioned_schema.as_ref())?, - "int_col".to_string(), + "partition_col", ), + ProjectionExpr::new(col("int_col", partitioned_schema.as_ref())?, "int_col"), ], source, )?); @@ 
-1678,11 +1658,214 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { .indent(true) .to_string(); let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col], file_type=csv, has_header=false" + ); + + Ok(()) +} + +#[test] +fn test_cooperative_exec_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let cooperative: Arc = Arc::new(CooperativeExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ], + cooperative, + )?); + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as a, b@1 as b] + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // Projection should be pushed down through CooperativeExec + assert_snapshot!( + actual, + @r" + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false + " + ); + + Ok(()) +} + +#[test] +fn test_hash_join_empty_projection_embeds() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join = Arc::new(HashJoinExec::try_new( + left_csv, + right_csv, + vec![(Arc::new(Column::new("a", 0)), Arc::new(Column::new("a", 0)))], + None, + &JoinType::Right, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?); + + // Empty projection: no columns needed from the join output + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![] as Vec, + join, + )?); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // The empty projection should be embedded into the HashJoinExec, + // resulting in projection=[] on the join and no ProjectionExec wrapper. + assert_snapshot!( + actual, + @r" + HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, a@0)], projection=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + Ok(()) +} + +/// Regression test for +/// +/// When a `ProjectionExec` sits on top of a `FilterExec` that already carries +/// an embedded projection, the `ProjectionPushdown` optimizer must not panic. +/// +/// Before the fix, `FilterExecBuilder::from(self)` copied stale projection +/// indices (e.g. `[0, 1, 2]`). After swapping, the new input was narrower +/// (2 columns), so `.build()` panicked with "project index out of bounds". 
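The failure mode described above reduces to an index-validation invariant: an embedded projection stores column indices into its input, and those indices go stale the moment the input is swapped for a narrower plan. A minimal sketch of that invariant, independent of DataFusion's actual builder code (validate_projection is illustrative only):

    /// Illustrative check only: embedded projection indices must be
    /// re-derived (or at least re-validated) against the *new* input arity
    /// after a swap, instead of being copied verbatim from the old plan.
    fn validate_projection(indices: &[usize], input_columns: usize) -> Result<(), String> {
        match indices.iter().find(|&&i| i >= input_columns) {
            // e.g. stale indices [0, 1, 2] against a 2-column input
            // reproduce the "project index out of bounds" panic above
            Some(&i) => Err(format!(
                "project index {i} out of bounds, input has {input_columns} columns"
            )),
            None => Ok(()),
        }
    }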
+#[test] +fn test_filter_with_embedded_projection_after_projection() -> Result<()> { + // DataSourceExec: [a, b, c, d, e] + let csv = create_simple_csv_exec(); + + // FilterExec: a > 0, projection=[0, 1, 2] → output: [a, b, c] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )); + let filter: Arc = Arc::new( + FilterExecBuilder::new(predicate, csv) + .apply_projection(Some(vec![0, 1, 2]))? + .build()?, + ); + + // ProjectionExec: narrows [a, b, c] → [a, b] + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ], + filter, + )?); + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as a, b@1 as b] + FilterExec: a@0 > 0, projection=[a@0, b@1, c@2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + // This must not panic + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + FilterExec: a@0 > 0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false + " + ); + + Ok(()) +} + +/// Same as above, but the outer ProjectionExec also renames columns. +/// Ensures the rename is preserved after the projection pushdown swap. +#[test] +fn test_filter_with_embedded_projection_after_renaming_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + + // FilterExec: b > 10, projection=[0, 1, 2, 3] → output: [a, b, c, d] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter: Arc = Arc::new( + FilterExecBuilder::new(predicate, csv) + .apply_projection(Some(vec![0, 1, 2, 3]))? 
+ .build()?, + ); + + // ProjectionExec: [a as x, b as y] — narrows and renames + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "x"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "y"), + ], + filter, + )?); + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as x, b@1 as y] + FilterExec: b@1 > 10, projection=[a@0, b@1, c@2, d@3] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); assert_snapshot!( actual, @r" - ProjectionExec: expr=[string_col@1 as string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col@0 as int_col] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false + FilterExec: y@1 > 10 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as x, b@1 as y], file_type=csv, has_header=false " ); diff --git a/datafusion/core/tests/physical_optimizer/pushdown_sort.rs b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs new file mode 100644 index 0000000000000..e2700c3174a16 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs @@ -0,0 +1,1086 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for sort pushdown optimizer rule (Phase 1) +//! +//! Phase 1 tests verify that: +//! 1. Reverse scan is enabled (reverse_row_groups=true) +//! 2. SortExec is kept (because ordering is inexact) +//! 3. output_ordering remains unchanged +//! 4. Early termination is enabled for TopK queries +//! 5. 
Prefix matching works correctly + +use datafusion_physical_expr::expressions; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::pushdown_sort::PushdownSort; +use std::sync::Arc; + +use crate::physical_optimizer::test_utils::{ + OptimizationTest, TestScan, coalesce_partitions_exec, parquet_exec, + parquet_exec_with_sort, projection_exec, projection_exec_with_alias, + repartition_exec, schema, simple_projection_exec, sort_exec, sort_exec_with_fetch, + sort_expr, sort_expr_named, test_scan_with_ordering, +}; + +#[test] +fn test_sort_pushdown_disabled() { + // When pushdown is disabled, plan should remain unchanged + let schema = schema(); + let source = parquet_exec(schema.clone()); + let sort_exprs = LexOrdering::new(vec![sort_expr("a", &schema)]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), false), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_sort_pushdown_basic_phase1() { + // Phase 1: Reverse scan enabled, Sort kept, output_ordering unchanged + let schema = schema(); + + // Source has ASC NULLS LAST ordering (default) + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request DESC NULLS LAST ordering (exact reverse) + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_with_limit_phase1() { + // Phase 1: Sort with fetch enables early termination but keeps Sort + let schema = schema(); + + // Source has ASC ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request DESC ordering with limit + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec_with_fetch(desc_ordering, Some(10), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_multiple_columns_phase1() { + // Phase 1: Sort on multiple columns - reverse multi-column ordering + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a ASC NULLS FIRST, b DESC] ordering (exact reverse) + let reverse_ordering = + LexOrdering::new(vec![a.clone().asc().nulls_first(), b.reverse()]).unwrap(); + let plan = sort_exec(reverse_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// PREFIX MATCHING TESTS +// ============================================================================ + +#[test] +fn test_prefix_match_single_column() { + // Test prefix matching: source has [a DESC, b ASC], query needs [a ASC] + // After reverse: [a ASC, b DESC] which satisfies [a ASC] prefix + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC NULLS LAST] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request only [a ASC NULLS FIRST] - a prefix of the reversed ordering + let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec(prefix_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_prefix_match_with_limit() { + // Test prefix matching with LIMIT - important for TopK optimization + let schema = schema(); + + // Source has [a ASC, b DESC, c ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let c = sort_expr("c", &schema); + let source_ordering = + LexOrdering::new(vec![a.clone(), b.clone().reverse(), c]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a DESC NULLS LAST, b ASC NULLS FIRST] with LIMIT 100 + // This is a prefix (2 columns) of the reversed 3-column ordering + let prefix_ordering = + LexOrdering::new(vec![a.reverse(), b.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec_with_fetch(prefix_ordering, Some(100), 
source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=100), expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 DESC NULLS LAST, c@2 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=100), expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_prefix_match_through_transparent_nodes() { + // Test prefix matching works through transparent nodes + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC, c DESC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let c = sort_expr("c", &schema); + let source_ordering = + LexOrdering::new(vec![a.clone().reverse(), b, c.reverse()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + + // Request only [a ASC NULLS FIRST] - prefix of reversed ordering + let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec(prefix_ordering, repartition); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC, c@2 DESC NULLS LAST], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_exact_prefix_match_same_direction() { + // Test that when the requested sort [a DESC] matches a prefix of the source's + // natural ordering [a DESC, b ASC], the Sort is eliminated (Exact pushdown). 
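// A simplified, self-contained model of the prefix rule the surrounding
// tests pin down: a reverse scan satisfies the request iff the requested
// keys are a prefix of the element-wise reversal of the source ordering
// (so the request must not be longer than the source; see the
// longer-than-source case below). The real rule operates on
// LexOrdering/PhysicalSortExpr; `SortKey` here is illustrative only.
#[derive(Clone, Copy, PartialEq)]
struct SortKey {
    column: usize,
    descending: bool,
    nulls_first: bool,
}

// Reversing a key flips both the direction and the null placement.
fn reversed(k: SortKey) -> SortKey {
    SortKey {
        column: k.column,
        descending: !k.descending,
        nulls_first: !k.nulls_first,
    }
}

fn satisfied_by_reverse_scan(requested: &[SortKey], source: &[SortKey]) -> bool {
    requested.len() <= source.len()
        && requested.iter().zip(source).all(|(r, s)| *r == reversed(*s))
}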
+ let schema = schema(); + + // Source has [a DESC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a DESC] - same direction as source prefix, Sort should be eliminated + let same_direction = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let plan = sort_exec(same_direction, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_no_prefix_match_longer_than_source() { + // Test that prefix matching does NOT work if requested is longer than source + let schema = schema(); + + // Source has [a DESC] ordering (single column) + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a ASC, b DESC] - longer than source, can't be a prefix + let longer_ordering = + LexOrdering::new(vec![a.clone().asc().nulls_first(), b.reverse()]).unwrap(); + let plan = sort_exec(longer_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST], file_type=parquet + " + ); +} + +// ============================================================================ +// ORIGINAL TESTS +// ============================================================================ + +#[test] +fn test_sort_through_repartition() { + // Sort should push through RepartitionExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, repartition); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_nested_sorts() { + // Nested sort operations - only innermost can be optimized + let schema = schema(); + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let inner_sort = sort_exec(desc_ordering, source); + + let sort_exprs2 = LexOrdering::new(vec![b]).unwrap(); + let plan = sort_exec(sort_exprs2, inner_sort); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_non_sort_plans_unchanged() { + // Plans without SortExec should pass through unchanged + let schema = schema(); + let plan = parquet_exec(schema.clone()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_optimizer_properties() { + // Test optimizer metadata + let optimizer = PushdownSort::new(); + + assert_eq!(optimizer.name(), "PushdownSort"); + assert!(optimizer.schema_check()); +} + +#[test] +fn test_sort_through_coalesce_partitions() { + // Sort should push through CoalescePartitionsExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + let coalesce_parts = coalesce_partitions_exec(repartition); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce_parts); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_complex_plan_with_multiple_operators() { + // Test a complex plan with multiple operators between sort and source + let schema = schema(); + let a = sort_expr("a", &schema); + let 
source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + let coalesce_parts = coalesce_partitions_exec(repartition); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce_parts); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_multiple_sorts_different_columns() { + // Test nested sorts on different columns - only innermost can optimize + let schema = schema(); + let a = sort_expr("a", &schema); + let c = sort_expr("c", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // First sort by column 'a' DESC (reverse of source) + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let sort1 = sort_exec(desc_ordering, source); + + // Then sort by column 'c' (different column, can't optimize) + let sort_exprs2 = LexOrdering::new(vec![c]).unwrap(); + let plan = sort_exec(sort_exprs2, sort1); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_pushdown_for_unordered_source() { + // Verify pushdown does NOT happen for sources without ordering + let schema = schema(); + let source = parquet_exec(schema.clone()); // No output_ordering + let sort_exprs = LexOrdering::new(vec![sort_expr("a", &schema)]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_no_pushdown_for_non_reverse_sort() { + // Verify pushdown does NOT happen when sort doesn't reverse source ordering + let schema = schema(); + + // Source sorted by 'a' ASC + let a = sort_expr("a", &schema); + let b = 
sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request sort by 'b' (different column) + let sort_exprs = LexOrdering::new(vec![b]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_pushdown_through_blocking_node() { + // Test that pushdown works for inner sort even when outer sort is blocked + // Structure: Sort -> Aggregate (blocks pushdown) -> Sort -> Scan + // The outer sort can't push through aggregate, but the inner sort should still optimize + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use std::sync::Arc; + + let schema = schema(); + + // Bottom: DataSource with [a ASC NULLS LAST] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Inner Sort: [a DESC NULLS FIRST] - exact reverse, CAN push down to source + let inner_sort_ordering = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let inner_sort = sort_exec(inner_sort_ordering, source); + + // Middle: Aggregate (blocks pushdown from outer sort) + // GROUP BY a, COUNT(b) + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(expressions::Column::new("a", 0)) as _, + "a".to_string(), + )]); + + let count_expr = Arc::new( + AggregateExprBuilder::new( + count_udaf(), + vec![Arc::new(expressions::Column::new("b", 1)) as _], + ) + .schema(Arc::clone(&schema)) + .alias("COUNT(b)") + .build() + .unwrap(), + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + vec![count_expr], + vec![None], + inner_sort, + Arc::clone(&schema), + ) + .unwrap(), + ); + + // Outer Sort: [a ASC] - this CANNOT push down through aggregate + let outer_sort_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let plan = sort_exec(outer_sort_ordering, aggregate); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// 
============================================================================ +// PROJECTION TESTS +// ============================================================================ + +#[test] +fn test_sort_pushdown_through_simple_projection() { + // Sort pushes through projection with simple column references + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b (simple column references) + let projection = simple_projection_exec(source, vec![0, 1]); // columns a, b + + // Request [a DESC] - should push through projection to source + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_through_projection_with_alias() { + // Sort pushes through projection with column aliases + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a AS id, b AS value + let projection = projection_exec_with_alias(source, vec![(0, "id"), (1, "value")]); + + // Request [id DESC] - should map to [a DESC] and push down + let id_expr = sort_expr_named("id", 0); + let desc_ordering = LexOrdering::new(vec![id_expr.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[id@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as id, b@1 as value] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[id@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as id, b@1 as value] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_sort_pushdown_through_computed_projection() { + use datafusion_expr::Operator; + + // Sort should NOT push through projection with computed columns + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a+b as sum, c + let projection = projection_exec( + vec![ + ( + Arc::new(expressions::BinaryExpr::new( + Arc::new(expressions::Column::new("a", 0)), + Operator::Plus, + Arc::new(expressions::Column::new("b", 1)), + )) as Arc, + 
"sum".to_string(), + ), + ( + Arc::new(expressions::Column::new("c", 2)) as Arc, + "c".to_string(), + ), + ], + source, + ) + .unwrap(); + + // Request [sum DESC] - should NOT push down (sum is computed) + let sum_expr = sort_expr_named("sum", 0); + let desc_ordering = LexOrdering::new(vec![sum_expr.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[sum@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + b@1 as sum, c@2 as c] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[sum@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + b@1 as sum, c@2 as c] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_sort_pushdown_projection_reordered_columns() { + // Sort pushes through projection that reorders columns + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT c, b, a (columns reordered) + let projection = simple_projection_exec(source, vec![2, 1, 0]); // c, b, a + + // Request [a DESC] where a is now at index 2 in projection output + let a_expr_at_2 = sort_expr_named("a", 2); + let desc_ordering = LexOrdering::new(vec![a_expr_at_2.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@2 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[c@2 as c, b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@2 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[c@2 as c, b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_projection_with_limit() { + // Sort with LIMIT pushes through simple projection + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b + let projection = simple_projection_exec(source, vec![0, 1]); + + // Request [a DESC] with LIMIT 10 + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec_with_fetch(desc_ordering, Some(10), projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: 
expr=[a@0 as a, b@1 as b]
+            - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true
+        "
+    );
+}
+
+#[test]
+fn test_sort_pushdown_through_projection() {
+    // Sort pushes through a simple projection
+    let schema = schema();
+
+    // Source has [a ASC] ordering
+    let a = sort_expr("a", &schema);
+    let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap();
+    let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]);
+
+    // Projection: SELECT a, b
+    let projection = simple_projection_exec(source, vec![0, 1]);
+
+    // Request [a DESC]
+    let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap();
+    let plan = sort_exec(desc_ordering, projection);
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, PushdownSort::new(), true),
+        @r"
+        OptimizationTest:
+          input:
+            - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+            - ProjectionExec: expr=[a@0 as a, b@1 as b]
+            - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
+          output:
+            Ok:
+              - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+              - ProjectionExec: expr=[a@0 as a, b@1 as b]
+              - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true
+        "
+    );
+}
+
+#[test]
+fn test_sort_pushdown_projection_subset_of_columns() {
+    // Sort pushes through a projection that selects a subset of columns
+    let schema = schema();
+
+    // Source has [a ASC, b ASC] ordering
+    let a = sort_expr("a", &schema);
+    let b = sort_expr("b", &schema);
+    let source_ordering = LexOrdering::new(vec![a.clone(), b.clone()]).unwrap();
+    let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]);
+
+    // Projection: SELECT a (subset of columns)
+    let projection = simple_projection_exec(source, vec![0]);
+
+    // Request [a DESC]
+    let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap();
+    let plan = sort_exec(desc_ordering, projection);
+
+    insta::assert_snapshot!(
+        OptimizationTest::new(plan, PushdownSort::new(), true),
+        @r"
+        OptimizationTest:
+          input:
+            - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+            - ProjectionExec: expr=[a@0 as a]
+            - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=parquet
+          output:
+            Ok:
+              - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false]
+              - ProjectionExec: expr=[a@0 as a]
+              - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true
+        "
+    );
+}
+
+// ============================================================================
+// TESTSCAN DEMONSTRATION TESTS
+// ============================================================================
+// These tests use TestScan to demonstrate sort pushdown more clearly than the
+// Parquet-based scans: TestScan can accept ANY ordering (not just the
+// reverse) and displays the requested ordering explicitly in its output.
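As a rough mental model of why TestScan keeps these snapshots readable, assume it simply records whatever ordering the rule pushes into it and echoes that from its display string. The real TestScan lives in test_utils and implements ExecutionPlan; OrderingRecorder below is a hypothetical reduction to just that bookkeeping, with orderings modeled as plain strings:

    // Hypothetical reduction of the TestScan idea: remember the pushed-down
    // request and surface it in the display output so snapshots can assert
    // on the request itself, independent of whether it could be honored.
    #[derive(Default)]
    struct OrderingRecorder {
        output_ordering: Option<String>,
        requested_ordering: Option<String>,
    }

    impl OrderingRecorder {
        fn with_requested(mut self, ordering: &str) -> Self {
            self.requested_ordering = Some(ordering.to_string());
            self
        }

        fn display(&self) -> String {
            let mut parts = Vec::new();
            if let Some(o) = &self.output_ordering {
                parts.push(format!("output_ordering=[{o}]"));
            }
            if let Some(r) = &self.requested_ordering {
                parts.push(format!("requested_ordering=[{r}]"));
            }
            if parts.is_empty() {
                "TestScan".to_string()
            } else {
                format!("TestScan: {}", parts.join(", "))
            }
        }
    }

Under that model, `requested_ordering=[a@0 DESC NULLS LAST]` in the snapshots that follow is just the recorded request, which is what lets TestScan accept orderings a Parquet reverse scan could not.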
+ +#[test] +fn test_sort_pushdown_with_test_scan_basic() { + // Demonstrates TestScan showing requested ordering clearly + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a DESC] ordering + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC] + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC], requested_ordering=[a@0 DESC NULLS LAST] + " + ); +} + +#[test] +fn test_sort_pushdown_with_test_scan_multi_column() { + // Demonstrates TestScan with multi-column ordering + let schema = schema(); + + // Source has [a ASC, b DESC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone().reverse()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a DESC, b ASC] ordering (reverse of source) + let reverse_ordering = LexOrdering::new(vec![a.reverse(), b]).unwrap(); + let plan = sort_exec(reverse_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 DESC NULLS LAST], requested_ordering=[a@0 DESC NULLS LAST, b@1 ASC] + " + ); +} + +#[test] +fn test_sort_pushdown_with_test_scan_arbitrary_ordering() { + // Demonstrates that TestScan can accept ANY ordering (not just reverse) + // This is different from ParquetExec which only supports reverse scans + let schema = schema(); + + // Source has [a ASC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a ASC, b DESC] - NOT a simple reverse, but TestScan accepts it + let mixed_ordering = LexOrdering::new(vec![a, b.reverse()]).unwrap(); + let plan = sort_exec(mixed_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 ASC] + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 ASC], requested_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + " + ); +} + +// ============================================================================ +// EXACT PUSHDOWN TESTS (source guarantees ordering, SortExec removed) +// ============================================================================ + +#[test] +fn test_sort_pushdown_exact_no_fetch_no_limit() { + // When a source returns Exact (without fetch), the SortExec should be + 
// removed entirely with no GlobalLimitExec wrapper. + let schema = schema(); + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source = + Arc::new(TestScan::new(schema.clone(), vec![]).with_exact_pushdown(true)); + + let ordering = LexOrdering::new(vec![a, b.reverse()]).unwrap(); + let plan = sort_exec(ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan + output: + Ok: + - TestScan: requested_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + " + ); +} + +#[test] +fn test_sort_pushdown_exact_preserves_fetch_with_global_limit() { + // When a source returns Exact but does NOT support with_fetch(), + // the optimizer must wrap the result with GlobalLimitExec to preserve + // the LIMIT from the eliminated SortExec. + let schema = schema(); + let a = sort_expr("a", &schema); + let source = + Arc::new(TestScan::new(schema.clone(), vec![]).with_exact_pushdown(true)); + + let ordering = LexOrdering::new(vec![a]).unwrap(); + let plan = sort_exec_with_fetch(ordering, Some(10), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false] + - TestScan + output: + Ok: + - GlobalLimitExec: skip=0, fetch=10 + - TestScan: requested_ordering=[a@0 ASC] + " + ); +} + +#[test] +fn test_sort_pushdown_exact_preserves_fetch_with_source_support() { + // When a source returns Exact AND supports with_fetch(), + // the limit should be pushed into the source directly (no GlobalLimitExec). + let schema = schema(); + let a = sort_expr("a", &schema); + let source = Arc::new( + TestScan::new(schema.clone(), vec![]) + .with_exact_pushdown(true) + .with_supports_fetch(true), + ); + + let ordering = LexOrdering::new(vec![a]).unwrap(); + let plan = sort_exec_with_fetch(ordering, Some(10), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 ASC], preserve_partitioning=[false] + - TestScan + output: + Ok: + - TestScan: requested_ordering=[a@0 ASC], fetch=10 + " + ); +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs similarity index 79% rename from datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs rename to datafusion/core/tests/physical_optimizer/pushdown_utils.rs index 7d8a9c7c2125c..8b659e757aa2a 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs @@ -18,33 +18,31 @@ use arrow::datatypes::SchemaRef; use arrow::{array::RecordBatch, compute::concat_batches}; use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; -use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Result, config::ConfigOptions, internal_err}; use datafusion_datasource::{ - file::FileSource, file_scan_config::FileScanConfig, + PartitionedFile, file::FileSource, file_scan_config::FileScanConfig, file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, - file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, - 
schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile,
-    TableSchema,
+    file_stream::FileOpener, source::DataSourceExec,
 };
+use datafusion_physical_expr::projection::ProjectionExprs;
 use datafusion_physical_expr_common::physical_expr::fmt_sql;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_plan::filter::batch_filter;
 use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown};
 use datafusion_physical_plan::{
-    displayable,
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, displayable,
     filter::FilterExec,
     filter_pushdown::{
         ChildFilterDescription, ChildPushdownResult, FilterDescription,
         FilterPushdownPropagation,
     },
     metrics::ExecutionPlanMetricsSet,
-    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
 };
 use futures::StreamExt;
 use futures::{FutureExt, Stream};
 use object_store::ObjectStore;
 use std::{
-    any::Any,
     fmt::{Display, Formatter},
     pin::Pin,
     sync::Arc,
@@ -53,14 +51,17 @@ use std::{

 pub struct TestOpener {
     batches: Vec<RecordBatch>,
     batch_size: Option<usize>,
-    schema: Option<SchemaRef>,
-    projection: Option<Vec<usize>>,
+    projection: Option<ProjectionExprs>,
     predicate: Option<Arc<dyn PhysicalExpr>>,
 }

 impl FileOpener for TestOpener {
     fn open(&self, _partitioned_file: PartitionedFile) -> Result<FileOpenFuture> {
         let mut batches = self.batches.clone();
+        if self.batches.is_empty() {
+            return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed());
+        }
+        let schema = self.batches[0].schema();
         if let Some(batch_size) = self.batch_size {
             let batch = concat_batches(&batches[0].schema(), &batches)?;
             let mut new_batches = Vec::new();
@@ -71,27 +72,23 @@ impl FileOpener for TestOpener {
             }
             batches = new_batches.into_iter().collect();
         }
-        if let Some(schema) = &self.schema {
-            let factory = DefaultSchemaAdapterFactory::from_schema(Arc::clone(schema));
-            let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap();
-            let mut new_batches = Vec::new();
-            for batch in batches {
-                let batch = if let Some(predicate) = &self.predicate {
-                    batch_filter(&batch, predicate)?
-                } else {
-                    batch
-                };
-                let batch = batch.project(&projection).unwrap();
-                let batch = mapper.map_batch(batch).unwrap();
-                new_batches.push(batch);
-            }
-            batches = new_batches;
+        let mut new_batches = Vec::new();
+        for batch in batches {
+            let batch = if let Some(predicate) = &self.predicate {
+                batch_filter(&batch, predicate)?
+            } else {
+                batch
+            };
+            new_batches.push(batch);
         }
+        batches = new_batches;
+        if let Some(projection) = &self.projection {
+            let projector = projection.make_projector(&schema)?;
             batches = batches
                 .into_iter()
-                .map(|batch| batch.project(projection).unwrap())
+                .map(|batch| projector.project_batch(&batch).unwrap())
                 .collect();
         }
@@ -102,26 +99,28 @@ impl FileOpener for TestOpener {
 }

 /// A placeholder data source that accepts filter pushdown
-#[derive(Clone, Default)]
+#[derive(Clone)]
 pub struct TestSource {
     support: bool,
     predicate: Option<Arc<dyn PhysicalExpr>>,
-    statistics: Option<Statistics>,
     batch_size: Option<usize>,
     batches: Vec<RecordBatch>,
-    schema: Option<SchemaRef>,
     metrics: ExecutionPlanMetricsSet,
-    projection: Option<Vec<usize>>,
-    schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
+    projection: Option<ProjectionExprs>,
+    table_schema: datafusion_datasource::TableSchema,
 }

 impl TestSource {
-    pub fn new(support: bool, batches: Vec<RecordBatch>) -> Self {
+    pub fn new(schema: SchemaRef, support: bool, batches: Vec<RecordBatch>) -> Self {
+        let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]);
         Self {
             support,
             metrics: ExecutionPlanMetricsSet::new(),
             batches,
-            ..Default::default()
+            predicate: None,
+            batch_size: None,
+            projection: None,
+            table_schema,
         }
     }
 }
@@ -132,24 +131,19 @@ impl FileSource for TestSource {
         _object_store: Arc<dyn ObjectStore>,
         _base_config: &FileScanConfig,
         _partition: usize,
-    ) -> Arc<dyn FileOpener> {
-        Arc::new(TestOpener {
+    ) -> Result<Arc<dyn FileOpener>> {
+        Ok(Arc::new(TestOpener {
             batches: self.batches.clone(),
             batch_size: self.batch_size,
-            schema: self.schema.clone(),
             projection: self.projection.clone(),
             predicate: self.predicate.clone(),
-        })
+        }))
     }

     fn filter(&self) -> Option<Arc<dyn PhysicalExpr>> {
         self.predicate.clone()
     }

-    fn as_any(&self) -> &dyn Any {
-        todo!("should not be called")
-    }
-
     fn with_batch_size(&self, batch_size: usize) -> Arc<dyn FileSource> {
         Arc::new(TestSource {
             batch_size: Some(batch_size),
@@ -157,43 +151,10 @@ impl FileSource for TestSource {
         })
     }

-    fn with_schema(&self, schema: TableSchema) -> Arc<dyn FileSource> {
-        assert!(
-            schema.table_partition_cols().is_empty(),
-            "TestSource does not support partition columns"
-        );
-        Arc::new(TestSource {
-            schema: Some(schema.file_schema().clone()),
-            ..self.clone()
-        })
-    }
-
-    fn with_projection(&self, config: &FileScanConfig) -> Arc<dyn FileSource> {
-        Arc::new(TestSource {
-            projection: config.projection_exprs.as_ref().map(|p| p.column_indices()),
-            ..self.clone()
-        })
-    }
-
-    fn with_statistics(&self, statistics: Statistics) -> Arc<dyn FileSource> {
-        Arc::new(TestSource {
-            statistics: Some(statistics),
-            ..self.clone()
-        })
-    }
-
     fn metrics(&self) -> &ExecutionPlanMetricsSet {
         &self.metrics
     }

-    fn statistics(&self) -> Result<Statistics> {
-        Ok(self
-            .statistics
-            .as_ref()
-            .expect("statistics not set")
-            .clone())
-    }
-
     fn file_type(&self) -> &str {
         "test"
     }
@@ -247,18 +208,51 @@ impl FileSource for TestSource {
         }
     }

-    fn with_schema_adapter_factory(
+    fn try_pushdown_projection(
         &self,
-        schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
-    ) -> Result<Arc<dyn FileSource>> {
-        Ok(Arc::new(Self {
-            schema_adapter_factory: Some(schema_adapter_factory),
-            ..self.clone()
-        }))
+        projection: &ProjectionExprs,
+    ) -> Result<Option<Arc<dyn FileSource>>> {
+        if let Some(existing_projection) = &self.projection {
+            // Combine existing projection with new projection
+            let combined_projection = existing_projection.try_merge(projection)?;
+            Ok(Some(Arc::new(TestSource {
+                projection: Some(combined_projection),
+                table_schema: self.table_schema.clone(),
+                ..self.clone()
+            })))
+        } else {
+            Ok(Some(Arc::new(TestSource {
+                projection: Some(projection.clone()),
+                ..self.clone()
+            })))
+        }
     }
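    // Worked example of the `try_merge` above (hypothetical projections, used
    // only to illustrate the intent; these columns are not part of the tests):
    // if this source already projects `[a@0 + 1 AS x, b@1 AS b]` and the
    // optimizer later pushes down `[x@0 * 2 AS y]`, the merged projection is
    // expected to rewrite the new expressions against the file schema, yielding
    // a single combined projection `[(a@0 + 1) * 2 AS y]` that the opener
    // applies once per batch.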
-    fn schema_adapter_factory(&self) -> Option<Arc<dyn SchemaAdapterFactory>> {
-        self.schema_adapter_factory.clone()
+    fn projection(&self) -> Option<&ProjectionExprs> {
+        self.projection.as_ref()
+    }
+
+    fn table_schema(&self) -> &datafusion_datasource::TableSchema {
+        &self.table_schema
+    }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit predicate (filter) expression if present
+        if let Some(predicate) = &self.predicate {
+            f(predicate.as_ref())?;
+        }
+
+        // Visit projection expressions if present
+        if let Some(projection) = &self.projection {
+            for proj_expr in projection {
+                f(proj_expr.expr.as_ref())?;
+            }
+        }
+
+        Ok(TreeNodeRecursion::Continue)
     }
 }
@@ -289,14 +283,15 @@ impl TestScanBuilder {
     }

     pub fn build(self) -> Arc<dyn ExecutionPlan> {
-        let source = Arc::new(TestSource::new(self.support, self.batches));
-        let base_config = FileScanConfigBuilder::new(
-            ObjectStoreUrl::parse("test://").unwrap(),
+        let source = Arc::new(TestSource::new(
             Arc::clone(&self.schema),
-            source,
-        )
-        .with_file(PartitionedFile::new("test.parquet", 123))
-        .build();
+            self.support,
+            self.batches,
+        ));
+        let base_config =
+            FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source)
+                .with_file(PartitionedFile::new("test.parquet", 123))
+                .build();
         DataSourceExec::from_data_source(base_config)
     }
 }
@@ -335,11 +330,12 @@ impl TestStream {
     /// least one entry in data (for the schema)
     pub fn new(data: Vec<RecordBatch>) -> Self {
         // check that there is at least one entry in data and that all batches have the same schema
-        assert!(!data.is_empty(), "data must not be empty");
-        assert!(
-            data.iter().all(|batch| batch.schema() == data[0].schema()),
-            "all batches must have the same schema"
-        );
+        if let Some(first) = data.first() {
+            assert!(
+                data.iter().all(|batch| batch.schema() == first.schema()),
+                "all batches must have the same schema"
+            );
+        }
         Self {
             data,
             ..Default::default()
@@ -377,6 +373,7 @@ pub struct OptimizationTest {
 }

 impl OptimizationTest {
+    #[expect(clippy::needless_pass_by_value)]
     pub fn new<O>(
         input_plan: Arc<dyn ExecutionPlan>,
         opt: O,
@@ -488,11 +485,7 @@ impl ExecutionPlan for TestNode {
         "TestInsertExec"
     }

-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         self.input.properties()
     }
@@ -576,4 +569,13 @@ impl ExecutionPlan for TestNode {
         Ok(res)
     }
 }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit the predicate expression
+        f(self.predicate.as_ref())?;
+        Ok(TreeNodeRecursion::Continue)
+    }
 }
diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs
index 066e52614a12e..601667ea02c0d 100644
--- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs
+++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs
@@ -18,10 +18,10 @@
 use std::sync::Arc;

 use crate::physical_optimizer::test_utils::{
-    check_integrity, coalesce_batches_exec, coalesce_partitions_exec,
-    create_test_schema3, parquet_exec_with_sort, sort_exec,
-    sort_exec_with_preserve_partitioning, sort_preserving_merge_exec,
-    sort_preserving_merge_exec_with_fetch, stream_exec_ordered_with_projection,
+    check_integrity, coalesce_partitions_exec, create_test_schema3,
+    parquet_exec_with_sort, sort_exec, sort_exec_with_preserve_partitioning,
+    sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch,
+    stream_exec_ordered_with_projection,
 };

 use datafusion::prelude::SessionContext;
@@ -41,7 +41,6 @@ use
datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{ plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::datasource::memory::MemorySourceConfig; @@ -50,8 +49,8 @@ use datafusion_physical_plan::{ collect, displayable, ExecutionPlan, Partitioning, }; +use object_store::ObjectStoreExt; use object_store::memory::InMemory; -use object_store::ObjectStore; use rstest::rstest; use url::Url; @@ -138,7 +137,8 @@ impl ReplaceTest { assert!( res.is_ok(), "Some errors occurred while executing the optimized physical plan: {:?}\nPlan: {}", - res.unwrap_err(), optimized_plan_string + res.unwrap_err(), + optimized_plan_string ); } @@ -192,7 +192,7 @@ async fn test_replace_multiple_input_repartition_1( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -202,13 +202,13 @@ async fn test_replace_multiple_input_repartition_1( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -218,13 +218,13 @@ async fn test_replace_multiple_input_repartition_1( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -275,21 +275,21 @@ async fn 
test_with_inter_children_change_only( SortExec: expr=[a@0 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] Optimized: SortPreservingMergeExec: [a@0 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] "); }, @@ -300,11 +300,11 @@ async fn test_with_inter_children_change_only( SortExec: expr=[a@0 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC "); }, @@ -315,21 +315,21 @@ async fn test_with_inter_children_change_only( SortExec: expr=[a@0 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC Optimized: SortPreservingMergeExec: [a@0 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 
ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC "); } @@ -375,14 +375,14 @@ async fn test_replace_multiple_input_repartition_2( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -393,7 +393,7 @@ async fn test_replace_multiple_input_repartition_2( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -404,14 +404,14 @@ async fn test_replace_multiple_input_repartition_2( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -439,9 +439,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); - let sort = - sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), filter); let physical_plan = sort_preserving_merge_exec(ordering, sort); let run = ReplaceTest::new(physical_plan) @@ -457,19 +455,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - 
RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { @@ -477,11 +473,10 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( Input / Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, (Boundedness::Bounded, SortPreference::PreserveOrder) => { @@ -489,19 +484,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), 
input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -526,12 +519,9 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); - let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr, 8192); - let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); + let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec_2 = coalesce_batches_exec(filter, 8192); - let sort = - sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec_2); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), filter); let physical_plan = sort_preserving_merge_exec(ordering, sort); let run = ReplaceTest::new(physical_plan) @@ -547,21 +537,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { @@ -569,12 +555,10 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Input / Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 
+ RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, (Boundedness::Bounded, SortPreference::PreserveOrder) => { @@ -582,21 +566,17 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( Input: SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -621,8 +601,7 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); - let physical_plan = coalesce_partitions_exec(coalesce_batches_exec); + let physical_plan = coalesce_partitions_exec(filter); let run = ReplaceTest::new(physical_plan) .with_boundedness(boundedness) @@ -636,22 +615,20 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( assert_snapshot!(physical_plan, @r" Input / Optimized: CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { assert_snapshot!(physical_plan, @r" Input / Optimized: CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: 
partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because there is no executor with ordering requirement }, @@ -659,11 +636,10 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( assert_snapshot!(physical_plan, @r" Input / Optimized: CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -690,8 +666,7 @@ async fn test_with_multiple_replaceable_repartitions( let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches = coalesce_batches_exec(filter, 8192); - let repartition_hash_2 = repartition_exec_hash(coalesce_batches); + let repartition_hash_2 = repartition_exec_hash(filter); let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash_2); let physical_plan = sort_preserving_merge_exec(ordering, sort); @@ -709,20 +684,18 @@ async fn test_with_multiple_replaceable_repartitions( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, 
output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { @@ -731,11 +704,10 @@ async fn test_with_multiple_replaceable_repartitions( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, (Boundedness::Bounded, SortPreference::PreserveOrder) => { @@ -744,20 +716,18 @@ async fn test_with_multiple_replaceable_repartitions( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: c@1 > 3 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } } @@ -804,7 +774,7 @@ async fn test_not_replace_with_different_orderings( SortPreservingMergeExec: [c@1 ASC] SortExec: expr=[c@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -814,7 +784,7 @@ async fn test_not_replace_with_different_orderings( SortPreservingMergeExec: [c@1 ASC] SortExec: expr=[c@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), 
input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because ordering requirement of the executor is @@ -826,7 +796,7 @@ async fn test_not_replace_with_different_orderings( SortPreservingMergeExec: [c@1 ASC] SortExec: expr=[c@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -870,13 +840,13 @@ async fn test_with_lost_ordering( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -886,7 +856,7 @@ async fn test_with_lost_ordering( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -896,13 +866,13 @@ async fn test_with_lost_ordering( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -956,22 +926,22 @@ async fn test_with_lost_and_kept_ordering( SortExec: expr=[c@1 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: 
partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [c@1 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -982,11 +952,11 @@ async fn test_with_lost_and_kept_ordering( SortExec: expr=[c@1 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -997,22 +967,22 @@ async fn test_with_lost_and_kept_ordering( SortExec: expr=[c@1 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [c@1 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -1040,8 +1010,6 @@ async fn test_with_multiple_child_trees( }; let left_repartition_rr = repartition_exec_round_robin(left_source); 
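    // The two `repartition_exec_hash` calls below build plain RepartitionExec
    // nodes; the rule under test decides whether to swap them for their
    // order-preserving variants. For reference, a sketch of constructing the
    // order-preserving variant directly (assuming RepartitionExec's builder
    // API; illustrative only, not executed by this test):
    //
    //     let repartition = RepartitionExec::try_new(
    //         input,
    //         Partitioning::Hash(vec![col("c", &schema)?], 8),
    //     )?
    //     .with_preserve_order();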
let left_repartition_hash = repartition_exec_hash(left_repartition_rr); - let left_coalesce_partitions = - Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_ordering = [sort_expr("a", &schema)].into(); let right_source = match boundedness { @@ -1052,11 +1020,8 @@ async fn test_with_multiple_child_trees( }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); - let right_coalesce_partitions = - Arc::new(CoalesceBatchesExec::new(right_repartition_hash, 4096)); - let hash_join_exec = - hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); + let hash_join_exec = hash_join_exec(left_repartition_hash, right_repartition_hash); let ordering: LexOrdering = [sort_expr_default("a", &hash_join_exec.schema())].into(); let sort = sort_exec_with_preserve_partitioning(ordering.clone(), hash_join_exec); let physical_plan = sort_preserving_merge_exec(ordering, sort); @@ -1075,14 +1040,12 @@ async fn test_with_multiple_child_trees( SortPreservingMergeExec: [a@0 ASC] SortExec: expr=[a@0 ASC], preserve_partitioning=[true] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, (Boundedness::Bounded, _) => { @@ -1091,14 +1054,12 @@ async fn test_with_multiple_child_trees( SortPreservingMergeExec: [a@0 ASC] SortExec: expr=[a@0 ASC], preserve_partitioning=[true] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: 
partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true
+              DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST
             ");
             // Expected bounded results are the same with and without the flag, because the
             // ordering gets lost in an intermediate executor anyway.
             // Hence, no need to preserve existing ordering.
@@ -1166,8 +1127,8 @@ fn hash_join_exec(
 ) -> Arc<dyn ExecutionPlan> {
     let left_on = col("c", &left.schema()).unwrap();
     let right_on = col("c", &right.schema()).unwrap();
-    let left_col = left_on.as_any().downcast_ref::<Column>().unwrap();
-    let right_col = right_on.as_any().downcast_ref::<Column>().unwrap();
+    let left_col = left_on.downcast_ref::<Column>().unwrap();
+    let right_col = right_on.downcast_ref::<Column>().unwrap();
     Arc::new(
         HashJoinExec::try_new(
             left,
@@ -1178,6 +1139,7 @@
             None,
             PartitionMode::Partitioned,
             NullEquality::NullEqualsNothing,
+            false,
         )
         .unwrap(),
     )
@@ -1248,7 +1210,10 @@ fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> {
         )],
     );
     let res = plan_with_order_preserving_variants(requirements, false, true, Some(15));
-    assert_contains!(res.unwrap_err().to_string(), "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]");
+    assert_contains!(
+        res.unwrap_err().to_string(),
+        "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]"
+    );

     // Test sort is without fetch, expected to get the fetch value from the coalesced
     let requirements = OrderPreservationContext::new(
diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs
index 9867ed1733413..217570846d56e 100644
--- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs
+++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs
@@ -30,13 +30,13 @@ use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTab
 use datafusion::prelude::{CsvReadOptions, SessionContext};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{JoinType, Result, ScalarValue};
-use datafusion_physical_expr::expressions::{col, Literal};
 use datafusion_physical_expr::Partitioning;
+use datafusion_physical_expr::expressions::{Literal, col};
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
-use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan;
 use datafusion_physical_plan::repartition::RepartitionExec;
-use datafusion_physical_plan::{displayable, ExecutionPlan};
+use datafusion_physical_plan::{ExecutionPlan, displayable};

 use async_trait::async_trait;

@@ -555,11 +555,11 @@ async fn test_sort_merge_join_satisfied() -> Result<()> {
     assert_snapshot!(
         actual,
         @r"
-    SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]
-      RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1
+    SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)]
+      RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true
         SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]
           DataSourceExec: partitions=1, partition_sizes=[0]
-      RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+      RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
         SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
           DataSourceExec: partitions=1, partition_sizes=[0]
     "
@@ -605,8 +605,8 @@ async fn test_sort_merge_join_order_missing() -> Result<()> {
assert_snapshot!( actual, @r" - SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] - RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 @@ -653,11 +653,11 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { assert_snapshot!( actual, @r" - SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] - RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] " diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 8ca33f3d4abb9..6814ab2358ffc 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -17,8 +17,7 @@ //! Test utilities for physical optimizer tests -use std::any::Any; -use std::fmt::Formatter; +use std::fmt::{Display, Formatter}; use std::sync::{Arc, LazyLock}; use arrow::array::Int32Array; @@ -31,27 +30,32 @@ use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{ColumnStatistics, JoinType, NullEquality, Result, Statistics}; +use datafusion_common::{ + ColumnStatistics, JoinType, NullEquality, Result, Statistics, internal_err, +}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{self, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use 
datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion_physical_plan::filter::FilterExec;
 use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn};
 use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
@@ -63,18 +67,17 @@ use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeE
 use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec};
 use datafusion_physical_plan::tree_node::PlanContext;
 use datafusion_physical_plan::union::UnionExec;
-use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec};
+use datafusion_physical_plan::windows::{BoundedWindowAggExec, create_window_expr};
 use datafusion_physical_plan::{
-    displayable, DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode,
-    Partitioning, PlanProperties,
+    DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, Partitioning,
+    PlanProperties, SortOrderPushdownResult, displayable,
 };

 /// Create a non sorted parquet exec
 pub fn parquet_exec(schema: SchemaRef) -> Arc<DataSourceExec> {
     let config = FileScanConfigBuilder::new(
         ObjectStoreUrl::parse("test:///").unwrap(),
-        schema,
-        Arc::new(ParquetSource::default()),
+        Arc::new(ParquetSource::new(schema)),
     )
     .with_file(PartitionedFile::new("x".to_string(), 100))
     .build();
@@ -89,8 +92,7 @@ pub(crate) fn parquet_exec_with_sort(
 ) -> Arc<DataSourceExec> {
     let config = FileScanConfigBuilder::new(
         ObjectStoreUrl::parse("test:///").unwrap(),
-        schema,
-        Arc::new(ParquetSource::default()),
+        Arc::new(ParquetSource::new(schema)),
     )
     .with_file(PartitionedFile::new("x".to_string(), 100))
     .with_output_ordering(output_ordering)
@@ -106,6 +108,7 @@ fn int64_stats() -> ColumnStatistics {
     max_value: Precision::Exact(1_000_000.into()),
     min_value: Precision::Exact(0.into()),
     distinct_count: Precision::Absent,
+    byte_size: Precision::Absent,
     }
 }
@@ -127,17 +130,13 @@ pub(crate) fn parquet_exec_with_stats(file_size: u64) -> Arc<DataSourceExec> {
     let config = FileScanConfigBuilder::new(
         ObjectStoreUrl::parse("test:///").unwrap(),
-        schema(),
-        Arc::new(ParquetSource::new(Default::default())),
+        Arc::new(ParquetSource::new(schema())),
     )
     .with_file(PartitionedFile::new("x".to_string(), file_size))
     .with_statistics(statistics)
     .build();

-    assert_eq!(
-        config.file_source.statistics().unwrap().num_rows,
-        Precision::Inexact(10000)
-    );
+    assert_eq!(config.statistics().num_rows, Precision::Inexact(10000));

     DataSourceExec::from_data_source(config)
 }

@@ -249,6 +248,7 @@ pub fn hash_join_exec(
         None,
         PartitionMode::Partitioned,
         NullEquality::NullEqualsNothing,
+        false,
     )?))
 }

@@ -361,13 +361,6 @@ pub fn aggregate_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
     )
 }

-pub fn coalesce_batches_exec(
-    input: Arc<dyn ExecutionPlan>,
-    batch_size: usize,
-) -> Arc<dyn ExecutionPlan> {
-    Arc::new(CoalesceBatchesExec::new(input, batch_size))
-}
-
 pub fn sort_exec(
     ordering: LexOrdering,
     input: Arc<dyn ExecutionPlan>,
@@ -458,19 +451,16 @@ impl ExecutionPlan for RequirementsTestExec {
         "RequiredInputOrderingExec"
     }

-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         self.input.properties()
     }

     fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
-        vec![self
-            .required_input_ordering
-            .as_ref()
-            .map(|ordering| OrderingRequirements::from(ordering.clone()))]
+        vec![
+            self.required_input_ordering
+                .as_ref()
+                .map(|ordering| OrderingRequirements::from(ordering.clone())),
+        ]
     }

     fn maintains_input_order(&self) -> Vec<bool> {
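(The hunk below adds `apply_expressions` to RequirementsTestExec. This is the
expression-visitor hook threaded through the test nodes in this PR: the callback
`f` is applied to every `PhysicalExpr` the node embeds, and the returned
`TreeNodeRecursion` lets a visitor short-circuit. A minimal sketch for a node
holding a single optional expression — `self.expr` here is hypothetical —
mirroring the implementations added in this diff:)

    fn apply_expressions(
        &self,
        f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
    ) -> Result<TreeNodeRecursion> {
        let mut tnr = TreeNodeRecursion::Continue;
        if let Some(expr) = &self.expr {
            // `visit_sibling` keeps visiting subsequent expressions unless the
            // callback asks to stop
            tnr = tnr.visit_sibling(|| f(expr.as_ref()))?;
        }
        Ok(tnr)
    }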
@@ -499,6 +489,20 @@ impl ExecutionPlan for RequirementsTestExec {
     ) -> Result<SendableRecordBatchStream> {
         unimplemented!("Test exec does not support execution")
     }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in required_input_ordering if present
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = &self.required_input_ordering {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
+    }
 }

 /// A [`PlanContext`] object is susceptible to being left in an inconsistent state after
@@ -704,3 +708,361 @@ impl TestAggregate {
         }
     }
 }
+
+/// A harness for testing physical optimizers.
+#[derive(Debug)]
+pub struct OptimizationTest {
+    input: Vec<String>,
+    output: Result<Vec<String>, String>,
+}
+
+impl OptimizationTest {
+    pub fn new<O>(
+        input_plan: Arc<dyn ExecutionPlan>,
+        opt: O,
+        enable_sort_pushdown: bool,
+    ) -> Self
+    where
+        O: PhysicalOptimizerRule,
+    {
+        let input = format_execution_plan(&input_plan);
+        let input_schema = input_plan.schema();
+
+        let mut config = ConfigOptions::new();
+        config.optimizer.enable_sort_pushdown = enable_sort_pushdown;
+        let output_result = opt.optimize(input_plan, &config);
+        let output = output_result
+            .and_then(|plan| {
+                if opt.schema_check() && (plan.schema() != input_schema) {
+                    internal_err!(
+                        "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}",
+                        input_schema,
+                        plan.schema()
+                    )
+                } else {
+                    Ok(plan)
+                }
+            })
+            .map(|plan| format_execution_plan(&plan))
+            .map_err(|e| e.to_string());
+
+        Self { input, output }
+    }
+}
+
+impl Display for OptimizationTest {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "OptimizationTest:")?;
+        writeln!(f, "  input:")?;
+        for line in &self.input {
+            writeln!(f, "    - {line}")?;
+        }
+        writeln!(f, "  output:")?;
+        match &self.output {
+            Ok(output) => {
+                writeln!(f, "    Ok:")?;
+                for line in output {
+                    writeln!(f, "      - {line}")?;
+                }
+            }
+            Err(err) => {
+                writeln!(f, "    Err: {err}")?;
+            }
+        }
+        Ok(())
+    }
+}
+
+pub fn format_execution_plan(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
+    format_lines(&displayable(plan.as_ref()).indent(false).to_string())
+}
+
+fn format_lines(s: &str) -> Vec<String> {
+    s.trim().split('\n').map(|s| s.to_string()).collect()
+}
+
+/// Create a simple ProjectionExec with column indices (simplified version)
+pub fn simple_projection_exec(
+    input: Arc<dyn ExecutionPlan>,
+    columns: Vec<usize>,
+) -> Arc<dyn ExecutionPlan> {
+    let schema = input.schema();
+    let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = columns
+        .iter()
+        .map(|&i| {
+            let field = schema.field(i);
+            (
+                Arc::new(expressions::Column::new(field.name(), i))
+                    as Arc<dyn PhysicalExpr>,
+                field.name().to_string(),
+            )
+        })
+        .collect();
+
+    projection_exec(exprs, input).unwrap()
+}
+
+/// Create a ProjectionExec with column aliases
+pub fn projection_exec_with_alias(
+    input: Arc<dyn ExecutionPlan>,
+    columns: Vec<(usize, &str)>,
+) -> Arc<dyn ExecutionPlan> {
+    let schema = input.schema();
+    let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = columns
+        .iter()
+        .map(|&(i, alias)| {
+            (
+                Arc::new(expressions::Column::new(schema.field(i).name(), i))
+                    as Arc<dyn PhysicalExpr>,
+                alias.to_string(),
+            )
+        })
+        .collect();
+
+    projection_exec(exprs, input).unwrap()
+}
+
+/// Create a sort expression with custom name and index
+pub fn sort_expr_named(name: &str, index: usize) -> PhysicalSortExpr {
+    PhysicalSortExpr {
+        expr: Arc::new(expressions::Column::new(name, index)),
+        options: SortOptions::default(),
+    }
+}
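+// Typical use of the harness above together with `TestScan` below (a sketch
+// drawn from the sort-pushdown tests earlier in this diff; `PushdownSort`,
+// `sort_expr`, and `sort_exec` come from those tests, and the snapshot body
+// is elided):
+//
+//     let a = sort_expr("a", &schema);
+//     let ordering = LexOrdering::new(vec![a]).unwrap();
+//     let scan = Arc::new(TestScan::new(schema.clone(), vec![]).with_exact_pushdown(true));
+//     let plan = sort_exec(ordering, scan);
+//     insta::assert_snapshot!(OptimizationTest::new(plan, PushdownSort::new(), true), @"...");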
+/// +/// ## Configuration +/// +/// - `exact_pushdown`: if `true`, `try_pushdown_sort` returns `Exact` +/// (source guarantees ordering, SortExec can be removed); if `false` +/// (default), returns `Inexact` (SortExec kept). +/// - `supports_fetch`: if `true`, `with_fetch()` returns `Some` so the +/// optimizer can push a LIMIT into the source; if `false` (default), +/// `with_fetch()` returns `None`, forcing a `GlobalLimitExec` wrapper. +#[derive(Debug, Clone)] +pub struct TestScan { + schema: SchemaRef, + output_ordering: Vec, + plan_properties: Arc, + // Store the requested ordering for display + requested_ordering: Option, + /// If true, `try_pushdown_sort` returns `Exact` instead of `Inexact`. + exact_pushdown: bool, + /// If true, `with_fetch()` returns `Some(...)` (source absorbs the limit). + supports_fetch: bool, + /// The fetch (LIMIT) value pushed into this scan via `with_fetch()`. + fetch: Option, +} + +impl TestScan { + /// Create a new TestScan with the given schema and output ordering + pub fn new(schema: SchemaRef, output_ordering: Vec) -> Self { + let eq_properties = if !output_ordering.is_empty() { + // Convert Vec to the format expected by new_with_orderings + // We need to extract the inner Vec from each LexOrdering + let orderings: Vec> = output_ordering + .iter() + .map(|lex_ordering| { + // LexOrdering implements IntoIterator, so we can collect it + lex_ordering.iter().cloned().collect() + }) + .collect(); + + EquivalenceProperties::new_with_orderings(Arc::clone(&schema), orderings) + } else { + EquivalenceProperties::new(Arc::clone(&schema)) + }; + + let plan_properties = PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + + Self { + schema, + output_ordering, + plan_properties: Arc::new(plan_properties), + requested_ordering: None, + exact_pushdown: false, + supports_fetch: false, + fetch: None, + } + } + + /// Create a TestScan with a single output ordering + pub fn with_ordering(schema: SchemaRef, ordering: LexOrdering) -> Self { + Self::new(schema, vec![ordering]) + } + + /// Set whether `try_pushdown_sort` returns `Exact` (true) or `Inexact` (false). + pub fn with_exact_pushdown(mut self, exact: bool) -> Self { + self.exact_pushdown = exact; + self + } + + /// Set whether `with_fetch()` returns `Some` (true) or `None` (false). 
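+    ///
+    /// For example (hypothetical usage), a limit-pushdown test might build
+    /// `TestScan::new(schema, vec![]).with_supports_fetch(true)` and then
+    /// assert that the optimizer pushes the LIMIT into the scan instead of
+    /// keeping a `GlobalLimitExec` wrapper.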
+ pub fn with_supports_fetch(mut self, supports: bool) -> Self { + self.supports_fetch = supports; + self + } +} + +impl DisplayAs for TestScan { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "TestScan")?; + let mut sep = ": "; + if !self.output_ordering.is_empty() { + write!(f, "{sep}output_ordering=[")?; + for (i, sort_expr) in self.output_ordering[0].iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]")?; + sep = ", "; + } + if let Some(ref req) = self.requested_ordering { + write!(f, "{sep}requested_ordering=[")?; + for (i, sort_expr) in req.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]")?; + sep = ", "; + } + if let Some(fetch) = self.fetch { + write!(f, "{sep}fetch={fetch}")?; + } + Ok(()) + } + DisplayFormatType::TreeRender => { + write!(f, "TestScan") + } + } + } +} + +impl ExecutionPlan for TestScan { + fn name(&self) -> &str { + "TestScan" + } + + fn properties(&self) -> &Arc { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(self) + } else { + internal_err!("TestScan should have no children") + } + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + internal_err!("TestScan is for testing optimizer only, not for execution") + } + + fn partition_statistics(&self, _partition: Option) -> Result> { + Ok(Arc::new(Statistics::new_unknown(&self.schema))) + } + + fn with_fetch(&self, fetch: Option) -> Option> { + if self.supports_fetch { + let mut new_scan = self.clone(); + new_scan.fetch = fetch; + Some(Arc::new(new_scan)) + } else { + None + } + } + + fn fetch(&self) -> Option { + self.fetch + } + + // This is the key method - implement sort pushdown + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // For testing purposes, accept ANY ordering request + // and create a new TestScan that shows what was requested + let requested_ordering = LexOrdering::new(order.to_vec()); + + let mut new_scan = self.clone(); + new_scan.requested_ordering = requested_ordering; + + if self.exact_pushdown { + // Update plan properties to reflect the guaranteed ordering + let orderings: Vec> = vec![order.to_vec()]; + let eq_properties = EquivalenceProperties::new_with_orderings( + Arc::clone(&self.schema), + orderings, + ); + new_scan.plan_properties = Arc::new(PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Ok(SortOrderPushdownResult::Exact { + inner: Arc::new(new_scan), + }) + } else { + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(new_scan), + }) + } + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in output_ordering + let mut tnr = TreeNodeRecursion::Continue; + for ordering in &self.output_ordering { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + + // Visit expressions in requested_ordering if present + if let Some(ordering) = &self.requested_ordering { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + + Ok(tnr) + } +} + +/// Helper function to create a TestScan with ordering +pub fn 
test_scan_with_ordering(
+    schema: SchemaRef,
+    ordering: LexOrdering,
+) -> Arc<dyn ExecutionPlan> {
+    Arc::new(TestScan::with_ordering(schema, ordering))
+}
diff --git a/datafusion/core/tests/physical_optimizer/window_optimize.rs b/datafusion/core/tests/physical_optimizer/window_optimize.rs
index fc1e6444d756e..796f6b6259716 100644
--- a/datafusion/core/tests/physical_optimizer/window_optimize.rs
+++ b/datafusion/core/tests/physical_optimizer/window_optimize.rs
@@ -26,10 +26,10 @@ mod test {
     use datafusion_expr::WindowFrame;
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_physical_expr::aggregate::AggregateExprBuilder;
-    use datafusion_physical_expr::expressions::{col, Column};
+    use datafusion_physical_expr::expressions::{Column, col};
     use datafusion_physical_expr::window::PlainAggregateWindowExpr;
     use datafusion_physical_plan::windows::BoundedWindowAggExec;
-    use datafusion_physical_plan::{common, ExecutionPlan, InputOrderMode};
+    use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, common};
     use std::sync::Arc;

     /// Test case for
diff --git a/datafusion/core/tests/physical_optimizer/window_topn.rs b/datafusion/core/tests/physical_optimizer/window_topn.rs
new file mode 100644
index 0000000000000..e3f73a85353cc
--- /dev/null
+++ b/datafusion/core/tests/physical_optimizer/window_topn.rs
@@ -0,0 +1,425 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for the WindowTopN physical optimizer rule.
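+//!
+//! Plan shapes below are taken from the snapshot tests in this module: the
+//! rule looks for a `FilterExec` on a `ROW_NUMBER()` window output, e.g.
+//!
+//! ```text
+//! FilterExec: row_number@2 <= k
+//!   BoundedWindowAggExec: ROW_NUMBER() PARTITION BY pk ORDER BY val
+//!     SortExec: expr=[pk ASC, val ASC]
+//! ```
+//!
+//! and rewrites the sort + filter into a per-partition top-K:
+//!
+//! ```text
+//! BoundedWindowAggExec: ROW_NUMBER() PARTITION BY pk ORDER BY val
+//!   PartitionedTopKExec: fetch=k, partition=[pk@0], order=[val@1 ASC]
+//! ```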
+
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion_common::Result;
+use datafusion_common::ScalarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::Operator;
+use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+use datafusion_functions_window::row_number::row_number_udwf;
+use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::expressions::{BinaryExpr, Column, col, lit};
+use datafusion_physical_expr::window::StandardWindowExpr;
+use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
+use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_optimizer::window_topn::WindowTopN;
+use datafusion_physical_plan::displayable;
+use datafusion_physical_plan::filter::FilterExec;
+use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
+use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::sorts::sort::SortExec;
+use datafusion_physical_plan::windows::{BoundedWindowAggExec, create_udwf_window_expr};
+use datafusion_physical_plan::{ExecutionPlan, InputOrderMode};
+use insta::assert_snapshot;
+
+fn schema() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("pk", DataType::Int64, false),
+        Field::new("val", DataType::Int64, false),
+    ]))
+}
+
+fn plan_str(plan: &dyn ExecutionPlan) -> String {
+    displayable(plan).indent(true).to_string()
+}
+
+fn optimize(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
+    let mut config = ConfigOptions::new();
+    config.optimizer.enable_window_topn = true;
+    WindowTopN::new().optimize(plan, &config)
+}
+
+fn optimize_disabled(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
+    let mut config = ConfigOptions::new();
+    config.optimizer.enable_window_topn = false;
+    WindowTopN::new().optimize(plan, &config)
+}
+
+/// Build: FilterExec(rn <= limit) → BoundedWindowAggExec(ROW_NUMBER() PARTITION BY pk ORDER BY val) → SortExec(pk, val)
+fn build_window_topn_plan(
+    limit_value: i64,
+    op: Operator,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let s = schema();
+    let input: Arc<dyn ExecutionPlan> =
+        Arc::new(PlaceholderRowExec::new(Arc::clone(&s)));
+
+    // Sort by pk ASC, val ASC
+    let ordering = LexOrdering::new(vec![
+        PhysicalSortExpr::new_default(col("pk", &s)?).asc(),
+        PhysicalSortExpr::new_default(col("val", &s)?).asc(),
+    ])
+    .unwrap();
+
+    let sort: Arc<dyn ExecutionPlan> =
+        Arc::new(SortExec::new(ordering.clone(), input).with_preserve_partitioning(true));
+
+    // ROW_NUMBER() OVER (PARTITION BY pk ORDER BY val)
+    let partition_by = vec![col("pk", &s)?];
+    let order_by = vec![PhysicalSortExpr::new_default(col("val", &s)?).asc()];
+
+    let window_expr = Arc::new(StandardWindowExpr::new(
+        create_udwf_window_expr(
+            &row_number_udwf(),
+            &[],
+            &s,
+            "row_number".to_string(),
+            false,
+        )?,
+        &partition_by,
+        &order_by,
+        Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Rows,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+            WindowFrameBound::CurrentRow,
+        )),
+    ));
+
+    let window: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
+        vec![window_expr],
+        sort,
+        InputOrderMode::Sorted,
+        true,
+    )?);
+
+    // FilterExec: rn op limit_value
+    // The ROW_NUMBER column is at index 2 (after pk=0, val=1)
+    let rn_col = Arc::new(Column::new("row_number", 2));
+    let limit_lit = lit(ScalarValue::UInt64(Some(limit_value as u64)));
+    let predicate = Arc::new(BinaryExpr::new(rn_col, op, limit_lit));
+    let filter: Arc<dyn ExecutionPlan> =
+        Arc::new(FilterExec::try_new(predicate, window)?);
+
+    Ok(filter)
+}
+
+/// Build a plan with no partition-by: ROW_NUMBER() OVER (ORDER BY val)
+fn build_window_topn_no_partition(limit_value: i64) -> Result<Arc<dyn ExecutionPlan>> {
+    let s = schema();
+    let input: Arc<dyn ExecutionPlan> =
+        Arc::new(PlaceholderRowExec::new(Arc::clone(&s)));
+
+    // Sort by val ASC only (no partition key)
+    let ordering =
+        LexOrdering::new(vec![PhysicalSortExpr::new_default(col("val", &s)?).asc()])
+            .unwrap();
+
+    let sort: Arc<dyn ExecutionPlan> =
+        Arc::new(SortExec::new(ordering.clone(), input).with_preserve_partitioning(true));
+
+    // ROW_NUMBER() OVER (ORDER BY val) — no partition by
+    let order_by = vec![PhysicalSortExpr::new_default(col("val", &s)?).asc()];
+
+    let window_expr = Arc::new(StandardWindowExpr::new(
+        create_udwf_window_expr(
+            &row_number_udwf(),
+            &[],
+            &s,
+            "row_number".to_string(),
+            false,
+        )?,
+        &[], // empty partition_by
+        &order_by,
+        Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Rows,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+            WindowFrameBound::CurrentRow,
+        )),
+    ));
+
+    let window: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
+        vec![window_expr],
+        sort,
+        InputOrderMode::Sorted,
+        true,
+    )?);
+
+    let rn_col = Arc::new(Column::new("row_number", 2));
+    let limit_lit = lit(ScalarValue::UInt64(Some(limit_value as u64)));
+    let predicate = Arc::new(BinaryExpr::new(rn_col, Operator::LtEq, limit_lit));
+    let filter: Arc<dyn ExecutionPlan> =
+        Arc::new(FilterExec::try_new(predicate, window)?);
+
+    Ok(filter)
+}
+
+/// Build a plan where filter is on a data column (not window output)
+fn build_non_window_filter_plan() -> Result<Arc<dyn ExecutionPlan>> {
+    let s = schema();
+    let input: Arc<dyn ExecutionPlan> =
+        Arc::new(PlaceholderRowExec::new(Arc::clone(&s)));
+
+    let ordering = LexOrdering::new(vec![
+        PhysicalSortExpr::new_default(col("pk", &s)?).asc(),
+        PhysicalSortExpr::new_default(col("val", &s)?).asc(),
+    ])
+    .unwrap();
+
+    let sort: Arc<dyn ExecutionPlan> =
+        Arc::new(SortExec::new(ordering.clone(), input).with_preserve_partitioning(true));
+
+    let partition_by = vec![col("pk", &s)?];
+    let order_by = vec![PhysicalSortExpr::new_default(col("val", &s)?).asc()];
+
+    let window_expr = Arc::new(StandardWindowExpr::new(
+        create_udwf_window_expr(
+            &row_number_udwf(),
+            &[],
+            &s,
+            "row_number".to_string(),
+            false,
+        )?,
+        &partition_by,
+        &order_by,
+        Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Rows,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+            WindowFrameBound::CurrentRow,
+        )),
+    ));
+
+    let window: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
+        vec![window_expr],
+        sort,
+        InputOrderMode::Sorted,
+        true,
+    )?);
+
+    // Filter on data column val (index 1), NOT on window output
+    let val_col = Arc::new(Column::new("val", 1));
+    let limit_lit = lit(ScalarValue::Int64(Some(3)));
+    let predicate = Arc::new(BinaryExpr::new(val_col, Operator::LtEq, limit_lit));
+    let filter: Arc<dyn ExecutionPlan> =
+        Arc::new(FilterExec::try_new(predicate, window)?);
+
+    Ok(filter)
+}
+
+#[test]
+fn basic_row_number_rn_lteq_3() -> Result<()> {
+    let plan = build_window_topn_plan(3, Operator::LtEq)?;
+    let optimized = optimize(plan)?;
+    assert_snapshot!(plan_str(optimized.as_ref()), @r#"
+    BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
+      PartitionedTopKExec: fetch=3, partition=[pk@0], order=[val@1 ASC]
+        PlaceholderRowExec
+    "#);
+    Ok(())
+}
+
+#[test]
+fn rn_lt_3_becomes_fetch_2() -> Result<()> {
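+    // `rn < 3` keeps rows with row_number 1 and 2, so the rewrite is expected
+    // to emit fetch = 2 (limit - 1 for a strict inequality).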
+    let plan = build_window_topn_plan(3, Operator::Lt)?;
+    let optimized = optimize(plan)?;
+    assert_snapshot!(plan_str(optimized.as_ref()), @r#"
+    BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
+      PartitionedTopKExec: fetch=2, partition=[pk@0], order=[val@1 ASC]
+        PlaceholderRowExec
+    "#);
+    Ok(())
+}
+
+#[test]
+fn flipped_3_gteq_rn() -> Result<()> {
+    let plan = {
+        let s = schema();
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(PlaceholderRowExec::new(Arc::clone(&s)));
+
+        let ordering = LexOrdering::new(vec![
+            PhysicalSortExpr::new_default(col("pk", &s)?).asc(),
+            PhysicalSortExpr::new_default(col("val", &s)?).asc(),
+        ])
+        .unwrap();
+
+        let sort: Arc<dyn ExecutionPlan> = Arc::new(
+            SortExec::new(ordering.clone(), input).with_preserve_partitioning(true),
+        );
+
+        let partition_by = vec![col("pk", &s)?];
+        let order_by = vec![PhysicalSortExpr::new_default(col("val", &s)?).asc()];
+
+        let window_expr = Arc::new(StandardWindowExpr::new(
+            create_udwf_window_expr(
+                &row_number_udwf(),
+                &[],
+                &s,
+                "row_number".to_string(),
+                false,
+            )?,
+            &partition_by,
+            &order_by,
+            Arc::new(WindowFrame::new_bounds(
+                WindowFrameUnits::Rows,
+                WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                WindowFrameBound::CurrentRow,
+            )),
+        ));
+
+        let window: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
+            vec![window_expr],
+            sort,
+            InputOrderMode::Sorted,
+            true,
+        )?);
+
+        // Flipped: 3 >= rn (Literal GtEq Column)
+        let rn_col = Arc::new(Column::new("row_number", 2));
+        let limit_lit = lit(ScalarValue::UInt64(Some(3)));
+        let predicate = Arc::new(BinaryExpr::new(limit_lit, Operator::GtEq, rn_col));
+        let filter: Arc<dyn ExecutionPlan> =
+            Arc::new(FilterExec::try_new(predicate, window)?);
+        filter
+    };
+
+    let optimized = optimize(plan)?;
+    assert_snapshot!(plan_str(optimized.as_ref()), @r#"
+    BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
+      PartitionedTopKExec: fetch=3, partition=[pk@0], order=[val@1 ASC]
+        PlaceholderRowExec
+    "#);
+    Ok(())
+}
+
+#[test]
+fn non_window_column_filter_no_change() -> Result<()> {
+    let plan = build_non_window_filter_plan()?;
+    let before = plan_str(plan.as_ref());
+    let optimized = optimize(plan)?;
+    let after = plan_str(optimized.as_ref());
+    assert_eq!(
+        before, after,
+        "Plan should not change when filter is on data column"
+    );
+    Ok(())
+}
+
+#[test]
+fn config_disabled_no_change() -> Result<()> {
+    let plan = build_window_topn_plan(3, Operator::LtEq)?;
+    let before = plan_str(plan.as_ref());
+    let optimized = optimize_disabled(plan)?;
+    let after = plan_str(optimized.as_ref());
+    assert_eq!(
+        before, after,
+        "Plan should not change when config is disabled"
+    );
+    Ok(())
+}
+
+#[test]
+fn no_partition_by_no_change() -> Result<()> {
+    // Without PARTITION BY, this is a global top-K which SortExec with
+    // fetch already handles — the rule should not fire.
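+    // (A plain `SortExec: TopK(fetch=N)` covers the global case, so a
+    // PartitionedTopKExec rewrite would add nothing here.)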
+    let plan = build_window_topn_no_partition(5)?;
+    let optimized = optimize(plan)?;
+    assert_snapshot!(plan_str(optimized.as_ref()), @r#"
+    FilterExec: row_number@2 <= 5
+      BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
+        SortExec: expr=[val@1 ASC], preserve_partitioning=[true]
+          PlaceholderRowExec
+    "#);
+    Ok(())
+}
+
+#[test]
+fn with_projection_between() -> Result<()> {
+    let s = schema();
+    let input: Arc<dyn ExecutionPlan> =
+        Arc::new(PlaceholderRowExec::new(Arc::clone(&s)));
+
+    let ordering = LexOrdering::new(vec![
+        PhysicalSortExpr::new_default(col("pk", &s)?).asc(),
+        PhysicalSortExpr::new_default(col("val", &s)?).asc(),
+    ])
+    .unwrap();
+
+    let sort: Arc<dyn ExecutionPlan> =
+        Arc::new(SortExec::new(ordering.clone(), input).with_preserve_partitioning(true));
+
+    let partition_by = vec![col("pk", &s)?];
+    let order_by = vec![PhysicalSortExpr::new_default(col("val", &s)?).asc()];
+
+    let window_expr = Arc::new(StandardWindowExpr::new(
+        create_udwf_window_expr(
+            &row_number_udwf(),
+            &[],
+            &s,
+            "row_number".to_string(),
+            false,
+        )?,
+        &partition_by,
+        &order_by,
+        Arc::new(WindowFrame::new_bounds(
+            WindowFrameUnits::Rows,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+            WindowFrameBound::CurrentRow,
+        )),
+    ));
+
+    let window: Arc<dyn ExecutionPlan> = Arc::new(BoundedWindowAggExec::try_new(
+        vec![window_expr],
+        sort,
+        InputOrderMode::Sorted,
+        true,
+    )?);
+
+    // Add a ProjectionExec between Filter and Window
+    let window_schema = window.schema();
+    let proj_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = window_schema
+        .fields()
+        .iter()
+        .enumerate()
+        .map(|(i, f)| {
+            (
+                Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
+                f.name().to_string(),
+            )
+        })
+        .collect();
+
+    let projection: Arc<dyn ExecutionPlan> =
+        Arc::new(ProjectionExec::try_new(proj_exprs, window)?);
+
+    // rn column is still at index 2 in the projected schema
+    let rn_col = Arc::new(Column::new("row_number", 2));
+    let limit_lit = lit(ScalarValue::UInt64(Some(3)));
+    let predicate = Arc::new(BinaryExpr::new(rn_col, Operator::LtEq, limit_lit));
+    let filter: Arc<dyn ExecutionPlan> =
+        Arc::new(FilterExec::try_new(predicate, projection)?);
+
+    let optimized = optimize(filter)?;
+    assert_snapshot!(plan_str(optimized.as_ref()), @r#"
+    ProjectionExec: expr=[pk@0 as pk, val@1 as val, row_number@2 as row_number]
+      BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
+        PartitionedTopKExec: fetch=3, partition=[pk@0], order=[val@1 ASC]
+          PlaceholderRowExec
+    "#);
+    Ok(())
+}
diff --git a/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs b/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs
deleted file mode 100644
index c3c92a9028d67..0000000000000
--- a/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs
+++ /dev/null
@@ -1,363 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use arrow::array::RecordBatch; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use bytes::{BufMut, BytesMut}; -use datafusion::common::Result; -use datafusion::datasource::listing::PartitionedFile; -use datafusion::datasource::physical_plan::{ - ArrowSource, CsvSource, FileSource, JsonSource, ParquetSource, -}; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; -use datafusion_common::ColumnStatistics; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, -}; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::object_store::ObjectStoreUrl; -use object_store::{memory::InMemory, path::Path, ObjectStore}; -use parquet::arrow::ArrowWriter; - -async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { - let mut out = BytesMut::new().writer(); - { - let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); - writer.write(&batch).unwrap(); - writer.finish().unwrap(); - } - let data = out.into_inner().freeze(); - store.put(&Path::from(path), data.into()).await.unwrap(); -} - -/// A schema adapter factory that transforms column names to uppercase -#[derive(Debug, PartialEq)] -struct UppercaseAdapterFactory {} - -impl SchemaAdapterFactory for UppercaseAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(UppercaseAdapter { - table_schema: projected_table_schema, - }) - } -} - -/// Schema adapter that transforms column names to uppercase -#[derive(Debug)] -struct UppercaseAdapter { - table_schema: SchemaRef, -} - -impl SchemaAdapter for UppercaseAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.table_schema.field(index); - let uppercase_name = field.name().to_uppercase(); - file_schema - .fields() - .iter() - .position(|f| f.name().to_uppercase() == uppercase_name) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - let mut projection = Vec::new(); - - // Map each field in the table schema to the corresponding field in the file schema - for table_field in self.table_schema.fields() { - let uppercase_name = table_field.name().to_uppercase(); - if let Some(pos) = file_schema - .fields() - .iter() - .position(|f| f.name().to_uppercase() == uppercase_name) - { - projection.push(pos); - } - } - - let mapper = UppercaseSchemaMapper { - output_schema: self.output_schema(), - projection: projection.clone(), - }; - - Ok((Arc::new(mapper), projection)) - } -} - -impl UppercaseAdapter { - fn output_schema(&self) -> SchemaRef { - let fields: Vec = self - .table_schema - .fields() - .iter() - .map(|f| { - Field::new( - f.name().to_uppercase().as_str(), - f.data_type().clone(), - f.is_nullable(), - ) - }) - .collect(); - - Arc::new(Schema::new(fields)) - } -} - -#[derive(Debug)] -struct UppercaseSchemaMapper { - output_schema: SchemaRef, - projection: Vec, -} - -impl SchemaMapper for 
UppercaseSchemaMapper { - fn map_batch(&self, batch: RecordBatch) -> Result { - let columns = self - .projection - .iter() - .map(|&i| batch.column(i).clone()) - .collect::>(); - Ok(RecordBatch::try_new(self.output_schema.clone(), columns)?) - } - - fn map_column_statistics( - &self, - stats: &[ColumnStatistics], - ) -> Result> { - Ok(self - .projection - .iter() - .map(|&i| stats.get(i).cloned().unwrap_or_default()) - .collect()) - } -} - -#[cfg(feature = "parquet")] -#[tokio::test] -async fn test_parquet_integration_with_schema_adapter() -> Result<()> { - // Create test data - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])), - vec![ - Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), - Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), - ], - )?; - - let store = Arc::new(InMemory::new()) as Arc; - let store_url = ObjectStoreUrl::parse("memory://").unwrap(); - let path = "test.parquet"; - write_parquet(batch.clone(), store.clone(), path).await; - - // Get the actual file size from the object store - let object_meta = store.head(&Path::from(path)).await?; - let file_size = object_meta.size; - - // Create a session context and register the object store - let ctx = SessionContext::new(); - ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); - - // Create a ParquetSource with the adapter factory - let file_source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {}))?; - - // Create a table schema with uppercase column names - let table_schema = Arc::new(Schema::new(vec![ - Field::new("ID", DataType::Int32, false), - Field::new("NAME", DataType::Utf8, true), - ])); - - let config = FileScanConfigBuilder::new(store_url, table_schema.clone(), file_source) - .with_file(PartitionedFile::new(path, file_size)) - .build(); - - // Create a data source executor - let exec = DataSourceExec::from_data_source(config); - - // Collect results - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx)?; - let batches = datafusion::physical_plan::common::collect(stream).await?; - - // There should be one batch - assert_eq!(batches.len(), 1); - - // Verify the schema has the uppercase column names - let result_schema = batches[0].schema(); - assert_eq!(result_schema.field(0).name(), "ID"); - assert_eq!(result_schema.field(1).name(), "NAME"); - - Ok(()) -} - -#[cfg(feature = "parquet")] -#[tokio::test] -async fn test_parquet_integration_with_schema_adapter_and_expression_rewriter( -) -> Result<()> { - // Create test data - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])), - vec![ - Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), - Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), - ], - )?; - - let store = Arc::new(InMemory::new()) as Arc; - let store_url = ObjectStoreUrl::parse("memory://").unwrap(); - let path = "test.parquet"; - write_parquet(batch.clone(), store.clone(), path).await; - - // Get the actual file size from the object store - let object_meta = store.head(&Path::from(path)).await?; - let file_size = object_meta.size; - - // Create a session context and register the object store - let ctx = SessionContext::new(); - ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); - - // Create a ParquetSource with the adapter factory - let file_source = 
ParquetSource::default() - .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {}))?; - - let config = FileScanConfigBuilder::new(store_url, batch.schema(), file_source) - .with_file(PartitionedFile::new(path, file_size)) - .build(); - - // Create a data source executor - let exec = DataSourceExec::from_data_source(config); - - // Collect results - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx)?; - let batches = datafusion::physical_plan::common::collect(stream).await?; - - // There should be one batch - assert_eq!(batches.len(), 1); - - // Verify the schema has the original column names (schema adapter not applied in DataSourceExec) - let result_schema = batches[0].schema(); - assert_eq!(result_schema.field(0).name(), "id"); - assert_eq!(result_schema.field(1).name(), "name"); - - Ok(()) -} - -#[tokio::test] -async fn test_multi_source_schema_adapter_reuse() -> Result<()> { - // This test verifies that the same schema adapter factory can be reused - // across different file source types. This is important for ensuring that: - // 1. The schema adapter factory interface works uniformly across all source types - // 2. The factory can be shared and cloned efficiently using Arc - // 3. Various data source implementations correctly implement the schema adapter factory pattern - - // Create a test factory - let factory = Arc::new(UppercaseAdapterFactory {}); - - // Test ArrowSource - { - let source = ArrowSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - // Test ParquetSource - #[cfg(feature = "parquet")] - { - let source = ParquetSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - // Test CsvSource - { - let source = CsvSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - // Test JsonSource - { - let source = JsonSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - 
); - } - - Ok(()) -} diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs new file mode 100644 index 0000000000000..464d6c937b328 --- /dev/null +++ b/datafusion/core/tests/set_comparison.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::{Result, assert_batches_eq, assert_contains}; + +fn build_table(values: &[i32]) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = + Arc::new(Int32Array::from(values.to_vec())) as Arc; + RecordBatch::try_new(schema, vec![array]).map_err(Into::into) +} + +#[tokio::test] +async fn set_comparison_any() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + // Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly. + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(5), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select v from s)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_any_aggregate_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 7])?)?; + ctx.register_batch("s", build_table(&[1, 2, 3])?)?; + + let df = ctx + .sql( + "select v from t where v > any(select sum(v) from s group by v % 2) order by v", + ) + .await?; + let results = df.collect().await?; + + assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_all_empty() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + ctx.register_batch( + "e", + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "v", + DataType::Int32, + true, + )]))), + )?; + + let df = ctx + .sql("select v from t where v < all(select v from e)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_type_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1])?)?; + ctx.register_batch("strings", { + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")])) + as Arc; + RecordBatch::try_new(schema, vec![array])? + })?; + + let df = ctx + .sql("select v from t where v > any(select s from strings)") + .await?; + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "expr type Int32 can't cast to Utf8 in SetComparison" + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_multiple_operators() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?; + ctx.register_batch("s", build_table(&[2, 3])?)?; + + let df = ctx + .sql("select v from t where v = any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v != all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v >= all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v <= any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &[ + "+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_null_semantics_all() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[5])?)?; + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(1), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v != all(select v from s)") + .await?; + let results = df.collect().await?; + let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(0, row_count); + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index 4b421b5294e01..3e5dc6a0b1872 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -365,7 +365,7 @@ async fn count_distinct_dictionary_all_null_values() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+---------------+ | cnt | count(t.num2) | +-----+---------------+ @@ -375,7 +375,7 @@ async fn count_distinct_dictionary_all_null_values() -> Result<()> { | 0 | 1 | | 0 | 1 | +-----+---------------+ - "### + " ); // Test with multiple partitions @@ -430,13 +430,68 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------------------------+ | count(DISTINCT t.dict) | +------------------------+ | 2 | +------------------------+ - "### + " + ); + + Ok(()) +} + +#[tokio::test] +async fn group_by_ree_dict_column() -> Result<()> { + let ctx = SessionContext::new(); + + let run_ends = Int32Array::from(vec![2, 4, 5]); + let dict = DictionaryArray::new( + UInt32Array::from(vec![0, 1, 2]), + Arc::new(StringArray::from(vec!["alpha", "beta", "gamma"])), + ); + let ree_col = RunArray::::try_new(&run_ends, &dict).unwrap(); + let value_col = Int32Array::from(vec![1, 2, 3, 4, 5]); + + let dict_type = + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![ + Field::new( + "group_col", + DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", dict_type, true)), + ), + true, + ), + Field::new("value", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(ree_col), Arc::new(value_col)], + )?; + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(table))?; + + let results = ctx + .sql("SELECT group_col, SUM(value) as total FROM t GROUP BY group_col ORDER BY group_col") + .await? 
+ .collect() + .await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +-----------+-------+ + | group_col | total | + +-----------+-------+ + | alpha | 3 | + | beta | 7 | + | gamma | 5 | + +-----------+-------+ + " ); Ok(()) diff --git a/datafusion/core/tests/sql/aggregates/dict_nulls.rs b/datafusion/core/tests/sql/aggregates/dict_nulls.rs index da4b2c8d25c9d..f9e15a71a20f8 100644 --- a/datafusion/core/tests/sql/aggregates/dict_nulls.rs +++ b/datafusion/core/tests/sql/aggregates/dict_nulls.rs @@ -34,7 +34,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { assert_snapshot!( batches_to_string(&results_count), - @r###" + @r" +----------------+-----+ | dict_null_keys | cnt | +----------------+-----+ @@ -42,7 +42,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { | group_a | 2 | | group_b | 1 | +----------------+-----+ - "### + " ); // Test SUM null handling with extended data @@ -69,7 +69,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { assert_snapshot!( batches_to_string(&results_min), - @r###" + @r" +----------------+---------+ | dict_null_keys | minimum | +----------------+---------+ @@ -78,7 +78,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { | group_b | 1 | | group_c | 7 | +----------------+---------+ - "### + " ); // Test MEDIAN null handling with median data @@ -168,7 +168,7 @@ async fn test_first_last_value_order_by_null_handling() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------------+-------+--------------------+---------------------+-------------------+--------------------+ | dict_group | value | first_ignore_nulls | first_respect_nulls | last_ignore_nulls | last_respect_nulls | +------------+-------+--------------------+---------------------+-------------------+--------------------+ @@ -178,7 +178,7 @@ async fn test_first_last_value_order_by_null_handling() -> Result<()> { | group_a | | 5 | | 20 | | | group_b | | 5 | | 20 | | +------------+-------+--------------------+---------------------+-------------------+--------------------+ - "### + " ); Ok(()) @@ -249,7 +249,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----------------+-----------+----------+-----+ | dict_null_keys | first_val | last_val | cnt | +----------------+-----------+----------+-----+ @@ -257,7 +257,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { | group_a | 10 | 50 | 2 | | group_b | 30 | 30 | 1 | +----------------+-----------+----------+-----+ - "### + " ); // Test GROUP BY with null values in dictionary @@ -275,7 +275,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results2), - @r###" + @r" +----------------+-----------+----------+-----+ | dict_null_vals | first_val | last_val | cnt | +----------------+-----------+----------+-----+ @@ -283,7 +283,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { | val_x | 10 | 50 | 2 | | val_y | 30 | 30 | 1 | +----------------+-----------+----------+-----+ - "### + " ); Ok(()) @@ -394,7 +394,7 @@ async fn test_count_distinct_with_fuzz_table_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +--------+----------+---------------------+------+------+ | u8_low | utf8_low | dictionary_utf8_low | col1 | col2 | +--------+----------+---------------------+------+------+ @@ -405,7 +405,7 @@ async 
fn test_count_distinct_with_fuzz_table_dict_nulls() -> Result<()> { | 20 | text_e | | 0 | 1 | | 25 | text_f | group_gamma | 1 | 1 | +--------+----------+---------------------+------+------+ - "### + " ); Ok(()) diff --git a/datafusion/core/tests/sql/aggregates/mod.rs b/datafusion/core/tests/sql/aggregates/mod.rs index 321c158628e43..ede40d5c4ceca 100644 --- a/datafusion/core/tests/sql/aggregates/mod.rs +++ b/datafusion/core/tests/sql/aggregates/mod.rs @@ -20,15 +20,15 @@ use super::*; use arrow::{ array::{ - types::UInt32Type, Decimal128Array, DictionaryArray, DurationNanosecondArray, - Int32Array, LargeBinaryArray, StringArray, TimestampMicrosecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, DictionaryArray, DurationNanosecondArray, Int32Array, + LargeBinaryArray, StringArray, TimestampMicrosecondArray, UInt8Array, + UInt16Array, UInt32Array, UInt64Array, types::UInt32Type, }, datatypes::{DataType, Field, Schema, TimeUnit}, record_batch::RecordBatch, }; use datafusion::{ - common::{test_util::batches_to_string, Result}, + common::{Result, test_util::batches_to_string}, execution::{config::SessionConfig, context::SessionContext}, }; use datafusion_catalog::MemTable; @@ -959,8 +959,8 @@ impl FuzzTimestampTestData { } /// Sets up test contexts for fuzz table with timestamps and both single and multiple partitions -pub async fn setup_fuzz_timestamp_test_contexts( -) -> Result<(SessionContext, SessionContext)> { +pub async fn setup_fuzz_timestamp_test_contexts() +-> Result<(SessionContext, SessionContext)> { let test_data = FuzzTimestampTestData::new(); // Single partition context diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 8d98b91547fe7..8ab0d150a7272 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -22,7 +22,7 @@ use rstest::rstest; use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::Timestamp; -use datafusion_common::format::ExplainAnalyzeLevel; +use datafusion_common::format::{ExplainAnalyzeCategories, MetricCategory, MetricType}; use object_store::path::Path; #[tokio::test] @@ -61,87 +61,92 @@ async fn explain_analyze_baseline_metrics() { assert_metrics!( &formatted, "AggregateExec: mode=Partial, gby=[]", - "metrics=[output_rows=3, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "AggregateExec: mode=Partial, gby=[]", - "output_bytes=" + "metrics=[output_rows=3, elapsed_compute=", + "output_bytes=", + "output_batches=3" ); assert_metrics!( &formatted, "AggregateExec: mode=Partial, gby=[c1@0 as c1]", - "reduction_factor=5.1% (5/99)" + "reduction_factor=5.05% (5/99)" ); - assert_metrics!( - &formatted, - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", - "output_bytes=" - ); - assert_metrics!( - &formatted, - "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", - "metrics=[output_rows=99, elapsed_compute=" - ); + { + let expected_batch_count_after_repartition = + if cfg!(not(feature = "force_hash_collisions")) { + "output_batches=3" + } else { + "output_batches=1" + }; + + assert_metrics!( + &formatted, + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + + 
assert_metrics!( + &formatted, + "RepartitionExec: partitioning=Hash([c1@0], 3), input_partitions=3", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + + assert_metrics!( + &formatted, + "ProjectionExec: expr=[]", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + } + assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", - "output_bytes=" + "metrics=[output_rows=99, elapsed_compute=", + "output_bytes=", + "output_batches=1" ); + assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", "selectivity=99% (99/100)" ); - assert_metrics!( - &formatted, - "ProjectionExec: expr=[]", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!(&formatted, "ProjectionExec: expr=[]", "output_bytes="); - assert_metrics!( - &formatted, - "CoalesceBatchesExec: target_batch_size=4096", - "metrics=[output_rows=5, elapsed_compute" - ); - assert_metrics!( - &formatted, - "CoalesceBatchesExec: target_batch_size=4096", - "output_bytes=" - ); + assert_metrics!( &formatted, "UnionExec", - "metrics=[output_rows=3, elapsed_compute=" + "metrics=[output_rows=3, elapsed_compute=", + "output_bytes=", + "output_batches=3" ); - assert_metrics!(&formatted, "UnionExec", "output_bytes="); + assert_metrics!( &formatted, "WindowAggExec", - "metrics=[output_rows=1, elapsed_compute=" + "metrics=[output_rows=1, elapsed_compute=", + "output_bytes=", + "output_batches=1" ); - assert_metrics!(&formatted, "WindowAggExec", "output_bytes="); fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { use datafusion::physical_plan; use datafusion::physical_plan::sorts; - plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() + plan.is::() + || plan.is::() + || plan.is::() + || plan.is::() + || plan.is::() + || plan.is::() + || plan.is::() + || plan.is::() } // Validate that the recorded elapsed compute time was more than @@ -200,7 +205,7 @@ fn nanos_from_timestamp(ts: &Timestamp) -> i64 { async fn collect_plan_with_context( sql_str: &str, ctx: &SessionContext, - level: ExplainAnalyzeLevel, + level: MetricType, ) -> String { { let state = ctx.state_ref(); @@ -214,7 +219,24 @@ async fn collect_plan_with_context( .to_string() } -async fn collect_plan(sql_str: &str, level: ExplainAnalyzeLevel) -> String { +async fn collect_plan_with_categories( + sql_str: &str, + categories: ExplainAnalyzeCategories, +) -> String { + let ctx = SessionContext::new(); + { + let state = ctx.state_ref(); + let mut state = state.write(); + state.config_mut().options_mut().explain.analyze_categories = categories; + } + let dataframe = ctx.sql(sql_str).await.unwrap(); + let batches = dataframe.collect().await.unwrap(); + arrow::util::pretty::pretty_format_batches(&batches) + .unwrap() + .to_string() +} + +async fn collect_plan(sql_str: &str, level: MetricType) -> String { let ctx = SessionContext::new(); collect_plan_with_context(sql_str, &ctx, level).await } @@ -227,10 +249,14 @@ async fn explain_analyze_level() { ORDER BY v1 DESC"; for (level, needle, should_contain) in [ - (ExplainAnalyzeLevel::Summary, "spill_count", 
false), - (ExplainAnalyzeLevel::Summary, "output_rows", true), - (ExplainAnalyzeLevel::Dev, "spill_count", true), - (ExplainAnalyzeLevel::Dev, "output_rows", true), + (MetricType::Summary, "spill_count", false), + (MetricType::Summary, "output_batches", false), + (MetricType::Summary, "output_rows", true), + (MetricType::Summary, "output_bytes", true), + (MetricType::Dev, "spill_count", true), + (MetricType::Dev, "output_rows", true), + (MetricType::Dev, "output_bytes", true), + (MetricType::Dev, "output_batches", true), ] { let plan = collect_plan(sql, level).await; assert_eq!( @@ -254,10 +280,10 @@ async fn explain_analyze_level_datasource_parquet() { .expect("register parquet table for explain analyze test"); for (level, needle, should_contain) in [ - (ExplainAnalyzeLevel::Summary, "metadata_load_time", true), - (ExplainAnalyzeLevel::Summary, "page_index_eval_time", false), - (ExplainAnalyzeLevel::Dev, "metadata_load_time", true), - (ExplainAnalyzeLevel::Dev, "page_index_eval_time", true), + (MetricType::Summary, "metadata_load_time", true), + (MetricType::Summary, "page_index_eval_time", false), + (MetricType::Dev, "metadata_load_time", true), + (MetricType::Dev, "page_index_eval_time", true), ] { let plan = collect_plan_with_context(&sql, &ctx, level).await; @@ -290,8 +316,7 @@ async fn explain_analyze_parquet_pruning_metrics() { "explain analyze select * from {table_name} where l_orderkey = {l_orderkey};" ); - let plan = - collect_plan_with_context(&sql, &ctx, ExplainAnalyzeLevel::Summary).await; + let plan = collect_plan_with_context(&sql, &ctx, MetricType::Summary).await; let expected_metrics = format!("files_ranges_pruned_statistics={expected_pruning_metrics}"); @@ -336,12 +361,12 @@ async fn csv_explain_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int64(10) TableScan: aggregate_test_100 - "### + " ); // // verify the grahviz format of the plan @@ -407,13 +432,12 @@ async fn csv_explain_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int8(10) TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] - - "### + " ); // // verify the grahviz format of the plan @@ -553,12 +577,12 @@ async fn csv_explain_verbose_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int64(10) TableScan: aggregate_test_100 - "### + " ); // // verify the grahviz format of the plan @@ -624,12 +648,12 @@ async fn csv_explain_verbose_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int8(10) TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] - "### + " ); // // verify the grahviz format of the plan @@ -748,19 +772,17 @@ async fn test_physical_plan_display_indent() { assert_snapshot!( actual, - @r###" + @r" SortPreservingMergeExec: [the_min@2 DESC], fetch=10 SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true] ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min] AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] - 
CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 - AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] - CoalesceBatchesExec: target_batch_size=4096 - FilterExec: c12@1 < 10 - RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 - DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true - "### + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + FilterExec: c12@1 < 10 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true + " ); } @@ -794,19 +816,13 @@ async fn test_physical_plan_display_indent_multi_children() { assert_snapshot!( actual, - @r###" - CoalesceBatchesExec: target_batch_size=4096 - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 - RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 - DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000 - RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 - ProjectionExec: expr=[c1@0 as c2] - DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true - "### + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1@0 as c2], file_type=csv, has_header=true + " ); } @@ -845,8 +861,7 @@ async fn csv_explain_analyze_order_by() { // Ensure that the ordering is not optimized away from the plan // https://github.com/apache/datafusion/issues/6379 - let needle = - "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; + let needle = "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; assert_contains!(&formatted, needle); } @@ -872,6 +887,8 @@ async fn parquet_explain_analyze() { &formatted, "row_groups_pruned_statistics=1 total \u{2192} 1 matched" ); + assert_contains!(&formatted, "output_rows_skew=0%"); + assert_contains!(&formatted, "scan_efficiency_ratio=13.99%"); // The order of metrics is expected to be the same as the actual pruning order // (file-> row-group -> page) @@ -879,13 +896,14 @@ async fn parquet_explain_analyze() { let i_rowgroup_stat = formatted.find("row_groups_pruned_statistics").unwrap(); let i_rowgroup_bloomfilter = formatted.find("row_groups_pruned_bloom_filter").unwrap(); - let i_page = 
formatted.find("page_index_rows_pruned").unwrap(); + let i_page_rows = formatted.find("page_index_rows_pruned").unwrap(); + let i_page_pages = formatted.find("page_index_pages_pruned").unwrap(); assert!( (i_file < i_rowgroup_stat) && (i_rowgroup_stat < i_rowgroup_bloomfilter) - && (i_rowgroup_bloomfilter < i_page), - "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." + && (i_rowgroup_bloomfilter < i_page_pages && i_page_pages < i_page_rows), + "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." ); } @@ -997,16 +1015,14 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> { RecursiveQueryExec: name=number_series, is_distinct=false CoalescePartitionsExec ProjectionExec: expr=[id@0 as id, 1 as level] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 - DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] + FilterExec: id@0 = 1 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] CoalescePartitionsExec ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 < 10 - RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 - WorkTableExec: name=number_series + FilterExec: id@0 < 10 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + WorkTableExec: name=number_series " ); @@ -1082,11 +1098,11 @@ async fn explain_physical_plan_only() { assert_snapshot!( actual, - @r###" + @r" physical_plan ProjectionExec: expr=[2 as count(*)] PlaceholderRowExec - "### + " ); } @@ -1110,3 +1126,144 @@ async fn csv_explain_analyze_with_statistics() { ", statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]]" ); } + +#[tokio::test] +async fn nested_loop_join_selectivity() { + for (join_type, expected_selectivity) in [ + ("INNER", "1% (1/100)"), + ("LEFT", "10% (10/100)"), + ("RIGHT", "10% (10/100)"), + // 1 match + 9 left + 9 right = 19 + ("FULL", "19% (19/100)"), + ] { + let ctx = SessionContext::new(); + let sql = format!( + "EXPLAIN ANALYZE SELECT * \ + FROM generate_series(1, 10) as t1(a) \ + {join_type} JOIN generate_series(1, 10) as t2(b) \ + ON (t1.a + t2.b) = 20" + ); + + let actual = execute_to_batches(&ctx, sql.as_str()).await; + let formatted = arrow::util::pretty::pretty_format_batches(&actual) + .unwrap() + .to_string(); + + assert_metrics!( + &formatted, + "NestedLoopJoinExec", + &format!("selectivity={expected_selectivity}") + ); + } +} + +#[tokio::test] +async fn explain_analyze_hash_join() { + let sql = "EXPLAIN ANALYZE \ + SELECT * \ + FROM generate_series(10) as t1(a) \ + JOIN generate_series(20) as t2(b) \ + ON t1.a=t2.b"; + + for (level, needle, should_contain) in [ + (MetricType::Summary, "probe_hit_rate", true), + (MetricType::Summary, "avg_fanout", true), + ] { + let plan = collect_plan(sql, level).await; + 
assert_eq!(
+            plan.contains(needle),
+            should_contain,
+            "plan for level {level:?} unexpected content: {plan}"
+        );
+    }
+}
+
+#[tokio::test]
+async fn explain_analyze_categories() {
+    let sql = "EXPLAIN ANALYZE \
+        SELECT * \
+        FROM generate_series(10) as t1(v1) \
+        ORDER BY v1 DESC";
+
+    for (categories, needle, should_contain) in [
+        // "rows" category: output_rows yes, elapsed_compute no, output_bytes no
+        (
+            ExplainAnalyzeCategories::Only(vec![MetricCategory::Rows]),
+            "output_rows",
+            true,
+        ),
+        (
+            ExplainAnalyzeCategories::Only(vec![MetricCategory::Rows]),
+            "elapsed_compute",
+            false,
+        ),
+        (
+            ExplainAnalyzeCategories::Only(vec![MetricCategory::Rows]),
+            "output_bytes",
+            false,
+        ),
+        // "none" — plan only, no metrics at all
+        (ExplainAnalyzeCategories::Only(vec![]), "output_rows", false),
+        (
+            ExplainAnalyzeCategories::Only(vec![]),
+            "elapsed_compute",
+            false,
+        ),
+        // "all" — everything shown
+        (ExplainAnalyzeCategories::All, "output_rows", true),
+        (ExplainAnalyzeCategories::All, "elapsed_compute", true),
+        (ExplainAnalyzeCategories::All, "output_bytes", true),
+        // "rows,bytes" — row + byte metrics, no timing
+        (
+            ExplainAnalyzeCategories::Only(vec![
+                MetricCategory::Rows,
+                MetricCategory::Bytes,
+            ]),
+            "output_rows",
+            true,
+        ),
+        (
+            ExplainAnalyzeCategories::Only(vec![
+                MetricCategory::Rows,
+                MetricCategory::Bytes,
+            ]),
+            "output_bytes",
+            true,
+        ),
+        (
+            ExplainAnalyzeCategories::Only(vec![
+                MetricCategory::Rows,
+                MetricCategory::Bytes,
+            ]),
+            "elapsed_compute",
+            false,
+        ),
+        // "rows,bytes,uncategorized" — everything except timing
+        (
+            ExplainAnalyzeCategories::Only(vec![
+                MetricCategory::Rows,
+                MetricCategory::Bytes,
+                MetricCategory::Uncategorized,
+            ]),
+            "output_rows",
+            true,
+        ),
+        (
+            ExplainAnalyzeCategories::Only(vec![
+                MetricCategory::Rows,
+                MetricCategory::Bytes,
+                MetricCategory::Uncategorized,
+            ]),
+            "elapsed_compute",
+            false,
+        ),
+    ] {
+        let plan = collect_plan_with_categories(sql, categories.clone()).await;
+        assert_eq!(
+            plan.contains(needle),
+            should_contain,
+            "plan for categories {categories:?} should{} contain '{needle}':\n{plan}",
+            if should_contain { "" } else { " NOT" }
+        );
+    }
+}
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 7a59834475920..7c0e89ee96418 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -38,14 +38,16 @@ async fn join_change_in_planner() -> Result<()> {
         Field::new("a2", DataType::UInt32, false),
     ]));
     // Specify the ordering:
-    let file_sort_order = vec![[col("a1")]
-        .into_iter()
-        .map(|e| {
-            let ascending = true;
-            let nulls_first = false;
-            e.sort(ascending, nulls_first)
-        })
-        .collect::<Vec<_>>()];
+    let file_sort_order = vec![
+        [col("a1")]
+            .into_iter()
+            .map(|e| {
+                let ascending = true;
+                let nulls_first = false;
+                e.sort(ascending, nulls_first)
+            })
+            .collect::<Vec<_>>(),
+    ];
     register_unbounded_file_with_ordering(
         &ctx,
         schema.clone(),
@@ -72,14 +74,10 @@ async fn join_change_in_planner() -> Result<()> {
         actual,
         @r"
     SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10
-      CoalesceBatchesExec: target_batch_size=8192
-        RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST
-          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-            StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
-      CoalesceBatchesExec: target_batch_size=8192
-        RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST
-          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-            StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
+      RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true
+        StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
+      RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true
+        StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
     "
     );
     Ok(())
@@ -99,14 +97,16 @@ async fn join_no_order_on_filter() -> Result<()> {
         Field::new("a3", DataType::UInt32, false),
     ]));
     // Specify the ordering:
-    let file_sort_order = vec![[col("a1")]
-        .into_iter()
-        .map(|e| {
-            let ascending = true;
-            let nulls_first = false;
-            e.sort(ascending, nulls_first)
-        })
-        .collect::<Vec<_>>()];
+    let file_sort_order = vec![
+        [col("a1")]
+            .into_iter()
+            .map(|e| {
+                let ascending = true;
+                let nulls_first = false;
+                e.sort(ascending, nulls_first)
+            })
+            .collect::<Vec<_>>(),
+    ];
     register_unbounded_file_with_ordering(
         &ctx,
         schema.clone(),
@@ -133,14 +133,10 @@ async fn join_no_order_on_filter() -> Result<()> {
         actual,
         @r"
     SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10
-      CoalesceBatchesExec: target_batch_size=8192
-        RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8
-          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-            StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
-      CoalesceBatchesExec: target_batch_size=8192
-        RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8
-          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-            StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
+      RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true
+        StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
+      RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true
+        StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST]
    "
    );
    Ok(())
@@ -176,14 +172,10 @@ async fn join_change_in_planner_without_sort() -> Result<()> {
         actual,
         @r"
     SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10
-      CoalesceBatchesExec: target_batch_size=8192
-        RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8
-          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-            StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true
-      CoalesceBatchesExec: target_batch_size=8192
-        RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8
-          RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
-            StreamingTableExec: partition_sizes=1, projection=[a1, a2],
infinite_source=true + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true " ); Ok(()) @@ -214,7 +206,10 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), Err(e) => { - assert_eq!(e.strip_backtrace(), "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") + assert_eq!( + e.strip_backtrace(), + "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag" + ) } } Ok(()) @@ -295,16 +290,12 @@ async fn unparse_cross_join() -> Result<()> { .await?; let unopt_sql = plan_to_sql(df.logical_plan())?; - assert_snapshot!(unopt_sql, @r#" - SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0) - "#); + assert_snapshot!(unopt_sql, @"SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0)"); let optimized_plan = df.into_optimized_plan()?; let opt_sql = plan_to_sql(&optimized_plan)?; - assert_snapshot!(opt_sql, @r#" - SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0) - "#); + assert_snapshot!(opt_sql, @"SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0)"); Ok(()) } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 743c8750b5215..9a1dc5502ee60 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -24,10 +24,10 @@ use arrow::{ use datafusion::error::Result; use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan}; -use datafusion::physical_plan::collect; -use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::ExecutionPlanVisitor; +use datafusion::physical_plan::collect; +use datafusion::physical_plan::metrics::MetricValue; use datafusion::prelude::*; use datafusion::test_util; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; @@ -40,18 +40,24 @@ use std::io::Write; use std::path::PathBuf; use tempfile::TempDir; -/// A macro to assert that some particular line contains two substrings +/// A macro to assert that some particular line contains the given substrings /// -/// Usage: `assert_metrics!(actual, operator_name, metrics)` +/// Usage: `assert_metrics!(actual, operator_name, metrics_1, metrics_2, ...)` macro_rules! 
assert_metrics {
-    ($ACTUAL: expr, $OPERATOR_NAME: expr, $METRICS: expr) => {
+    ($ACTUAL: expr, $OPERATOR_NAME: expr, $($METRICS: expr),+) => {
         let found = $ACTUAL
             .lines()
-            .any(|line| line.contains($OPERATOR_NAME) && line.contains($METRICS));
+            .any(|line| line.contains($OPERATOR_NAME) $( && line.contains($METRICS))+);
+
+        let mut metrics = String::new();
+        $(metrics.push_str(format!(" '{}',", $METRICS).as_str());)+
+        // remove the last `,` from the string
+        metrics.pop();
+
         assert!(
             found,
-            "Can not find a line with both '{}' and '{}' in\n\n{}",
-            $OPERATOR_NAME, $METRICS, $ACTUAL
+            "Cannot find a line with operator name '{}' and metrics containing values {} in:\n\n{}",
+            $OPERATOR_NAME, metrics, $ACTUAL
         );
     };
 }
@@ -64,6 +70,7 @@ mod path_partition;
 mod runtime_config;
 pub mod select;
 mod sql_api;
+mod unparser;
 
 async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
     let testdata = test_util::arrow_test_data();
@@ -329,8 +336,7 @@ async fn nyc() -> Result<()> {
     match &optimized_plan {
         LogicalPlan::Aggregate(Aggregate { input, .. }) => match input.as_ref() {
             LogicalPlan::TableScan(TableScan {
-                ref projected_schema,
-                ..
+                projected_schema, ..
             }) => {
                 assert_eq!(2, projected_schema.fields().len());
                 assert_eq!(projected_schema.field(0).name(), "passenger_count");
diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs
index 05cc723ef05fb..2eff1c262f855 100644
--- a/datafusion/core/tests/sql/path_partition.rs
+++ b/datafusion/core/tests/sql/path_partition.rs
@@ -20,7 +20,6 @@
 use std::collections::BTreeSet;
 use std::fs::File;
 use std::io::{Read, Seek, SeekFrom};
-use std::ops::Range;
 use std::sync::Arc;
 
 use arrow::datatypes::DataType;
@@ -31,26 +30,28 @@ use datafusion::{
         listing::{ListingOptions, ListingTable, ListingTableConfig},
     },
     error::Result,
-    physical_plan::ColumnStatistics,
     prelude::SessionContext,
     test_util::{self, arrow_test_data, parquet_test_data},
 };
 use datafusion_catalog::TableProvider;
+use datafusion_common::ScalarValue;
 use datafusion_common::stats::Precision;
 use datafusion_common::test_util::batches_to_sort_string;
-use datafusion_common::ScalarValue;
 use datafusion_execution::config::SessionConfig;
 
 use async_trait::async_trait;
 use bytes::Bytes;
 use chrono::{TimeZone, Utc};
+use futures::StreamExt;
 use futures::stream::{self, BoxStream};
 use insta::assert_snapshot;
 use object_store::{
-    path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta,
-    ObjectStore, PutOptions, PutResult,
+    Attributes, CopyOptions, GetRange, MultipartUpload, PutMultipartOptions, PutPayload,
+};
+use object_store::{
+    GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore,
+    PutOptions, PutResult, path::Path,
 };
-use object_store::{Attributes, MultipartUpload, PutMultipartOptions, PutPayload};
 use url::Url;
 
 #[tokio::test]
@@ -460,14 +461,26 @@ async fn parquet_statistics() -> Result<()> {
     let schema = physical_plan.schema();
     assert_eq!(schema.fields().len(), 4);
 
-    let stat_cols = physical_plan.partition_statistics(None)?.column_statistics;
+    let stat_cols = physical_plan
+ .column_statistics + .clone(); assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); - // TODO assert partition column (1,2,3) stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); - assert_eq!(stat_cols[2], ColumnStatistics::new_unknown(),); - assert_eq!(stat_cols[3], ColumnStatistics::new_unknown(),); + // Partition column statistics (year=2021 for all 3 rows) + assert_eq!(stat_cols[1].null_count, Precision::Exact(0)); + assert_eq!( + stat_cols[1].min_value, + Precision::Exact(ScalarValue::Int32(Some(2021))) + ); + assert_eq!( + stat_cols[1].max_value, + Precision::Exact(ScalarValue::Int32(Some(2021))) + ); + // month and day are Utf8 partition columns with statistics + assert_eq!(stat_cols[2].null_count, Precision::Exact(0)); + assert_eq!(stat_cols[3].null_count, Precision::Exact(0)); //// WITH PROJECTION //// let dataframe = ctx.sql("SELECT mycol, day FROM t WHERE day='28'").await?; @@ -475,12 +488,23 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; + let stat_cols = physical_plan + .partition_statistics(None)? + .column_statistics + .clone(); assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); - // TODO assert partition column stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown()); + // Partition column statistics for day='28' (1 row) + assert_eq!(stat_cols[1].null_count, Precision::Exact(0)); + assert_eq!( + stat_cols[1].min_value, + Precision::Exact(ScalarValue::Utf8(Some("28".to_string()))) + ); + assert_eq!( + stat_cols[1].max_value, + Precision::Exact(ScalarValue::Utf8(Some("28".to_string()))) + ); Ok(()) } @@ -604,7 +628,7 @@ async fn create_partitioned_alltypes_parquet_table( } #[derive(Debug)] -/// An object store implem that is mirrors a given file to multiple paths. +/// An object store implem that mirrors a given file to multiple paths. 
pub struct MirroringObjectStore {
     /// The `(path,size)` of the files that "exist" in the store
     files: Vec<Path>,
@@ -653,12 +677,13 @@ impl ObjectStore for MirroringObjectStore {
     async fn get_opts(
         &self,
         location: &Path,
-        _options: GetOptions,
+        options: GetOptions,
     ) -> object_store::Result<GetResult> {
         self.files.iter().find(|x| *x == location).unwrap();
         let path = std::path::PathBuf::from(&self.mirrored_file);
         let file = File::open(&path).unwrap();
         let metadata = file.metadata().unwrap();
+
         let meta = ObjectMeta {
             location: location.clone(),
             last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(),
@@ -667,37 +692,35 @@ impl ObjectStore for MirroringObjectStore {
             version: None,
         };
 
+        let payload = if options.head {
+            // no content for head requests
+            GetResultPayload::Stream(stream::empty().boxed())
+        } else if let Some(range) = options.range {
+            let GetRange::Bounded(range) = range else {
+                unimplemented!("Unbounded range not supported in MirroringObjectStore");
+            };
+            let mut file = File::open(path).unwrap();
+            file.seek(SeekFrom::Start(range.start)).unwrap();
+
+            let to_read = range.end - range.start;
+            let to_read: usize = to_read.try_into().unwrap();
+            let mut data = Vec::with_capacity(to_read);
+            let read = file.take(to_read as u64).read_to_end(&mut data).unwrap();
+            assert_eq!(read, to_read);
+            let stream = stream::once(async move { Ok(Bytes::from(data)) }).boxed();
+            GetResultPayload::Stream(stream)
+        } else {
+            GetResultPayload::File(file, path)
+        };
+
         Ok(GetResult {
             range: 0..meta.size,
-            payload: GetResultPayload::File(file, path),
+            payload,
             meta,
             attributes: Attributes::default(),
         })
     }
 
-    async fn get_range(
-        &self,
-        location: &Path,
-        range: Range<u64>,
-    ) -> object_store::Result<Bytes> {
-        self.files.iter().find(|x| *x == location).unwrap();
-        let path = std::path::PathBuf::from(&self.mirrored_file);
-        let mut file = File::open(path).unwrap();
-        file.seek(SeekFrom::Start(range.start)).unwrap();
-
-        let to_read = range.end - range.start;
-        let to_read: usize = to_read.try_into().unwrap();
-        let mut data = Vec::with_capacity(to_read);
-        let read = file.take(to_read as u64).read_to_end(&mut data).unwrap();
-        assert_eq!(read, to_read);
-
-        Ok(data.into())
-    }
-
-    async fn delete(&self, _location: &Path) -> object_store::Result<()> {
-        unimplemented!()
-    }
-
     fn list(
         &self,
         prefix: Option<&Path>,
@@ -712,6 +735,8 @@ impl ObjectStore for MirroringObjectStore {
             .map(|mut x| x.next().is_some())
             .unwrap_or(false);
 
+        #[expect(clippy::result_large_err)]
+        // closure only ever returns Ok; Err type is never constructed
         filter.then(|| {
             Ok(ObjectMeta {
                 location,
@@ -749,7 +774,7 @@ impl ObjectStore for MirroringObjectStore {
             };
 
             if parts.next().is_some() {
-                common_prefixes.insert(prefix.child(common_prefix));
+                common_prefixes.insert(prefix.clone().join(common_prefix));
             } else {
                 let object = ObjectMeta {
                     location: k.clone(),
@@ -767,14 +792,18 @@ impl ObjectStore for MirroringObjectStore {
         })
     }
 
-    async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> {
+    fn delete_stream(
+        &self,
+        _locations: BoxStream<'static, object_store::Result<Path>>,
+    ) -> BoxStream<'static, object_store::Result<Path>> {
         unimplemented!()
     }
 
-    async fn copy_if_not_exists(
+    async fn copy_opts(
         &self,
         _from: &Path,
         _to: &Path,
+        _options: CopyOptions,
     ) -> object_store::Result<()> {
         unimplemented!()
     }
 }
diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs
index 9627d7bccdb04..cf5237d725805 100644
--- a/datafusion/core/tests/sql/runtime_config.rs
+++
b/datafusion/core/tests/sql/runtime_config.rs @@ -18,9 +18,14 @@ //! Tests for runtime configuration SQL interface use std::sync::Arc; +use std::time::Duration; use datafusion::execution::context::SessionContext; use datafusion::execution::context::TaskContext; +use datafusion::prelude::SessionConfig; +use datafusion_execution::cache::DefaultListFilesCache; +use datafusion_execution::cache::cache_manager::CacheManagerConfig; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_physical_plan::common::collect; #[tokio::test] @@ -140,7 +145,7 @@ async fn test_memory_limit_enforcement() { } #[tokio::test] -async fn test_invalid_memory_limit() { +async fn test_invalid_memory_limit_when_unit_is_invalid() { let ctx = SessionContext::new(); let result = ctx @@ -149,7 +154,26 @@ async fn test_invalid_memory_limit() { assert!(result.is_err()); let error_message = result.unwrap_err().to_string(); - assert!(error_message.contains("Unsupported unit 'X'")); + assert!( + error_message + .contains("Unsupported unit 'X' in 'datafusion.runtime.memory_limit'") + && error_message.contains("Unit must be one of: 'K', 'M', 'G'") + ); +} + +#[tokio::test] +async fn test_invalid_memory_limit_when_limit_is_not_numeric() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.memory_limit = 'invalid_memory_limit'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains( + "Failed to parse number from 'datafusion.runtime.memory_limit', limit 'invalid_memory_limit'" + )); } #[tokio::test] @@ -233,6 +257,93 @@ async fn test_test_metadata_cache_limit() { assert_eq!(get_limit(&ctx), 123 * 1024); } +#[tokio::test] +async fn test_list_files_cache_limit() { + let list_files_cache = Arc::new(DefaultListFilesCache::default()); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default().with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.list_files_cache_limit = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> usize { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .cache_limit() + }; + + update_limit(&ctx, "100M").await; + assert_eq!(get_limit(&ctx), 100 * 1024 * 1024); + + update_limit(&ctx, "2G").await; + assert_eq!(get_limit(&ctx), 2 * 1024 * 1024 * 1024); + + update_limit(&ctx, "123K").await; + assert_eq!(get_limit(&ctx), 123 * 1024); +} + +#[tokio::test] +async fn test_list_files_cache_ttl() { + let list_files_cache = Arc::new(DefaultListFilesCache::default()); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default().with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.list_files_cache_ttl = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> Duration { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .cache_ttl() + .unwrap() + }; + + update_limit(&ctx, 
"1m").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(60)); + + update_limit(&ctx, "30s").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(30)); + + update_limit(&ctx, "1m30s").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(90)); +} + #[tokio::test] async fn test_unknown_runtime_config() { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 8a0f620627384..96b911e8db130 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -18,8 +18,7 @@ use std::collections::HashMap; use super::*; -use datafusion::assert_batches_eq; -use datafusion_common::{metadata::ScalarAndMetadata, ParamValues, ScalarValue}; +use datafusion_common::{ParamValues, ScalarValue, metadata::ScalarAndMetadata}; use insta::assert_snapshot; #[tokio::test] @@ -223,10 +222,10 @@ async fn test_parameter_invalid_types() -> Result<()> { .await; assert_snapshot!(results.unwrap_err().strip_backtrace(), @r" - type_coercion - caused by - Error during planning: Cannot infer common argument type for comparison operation List(nullable Int32) = Int32 - "); + type_coercion + caused by + Error during planning: Cannot infer common argument type for comparison operation List(Int32) = Int32 + "); Ok(()) } @@ -343,26 +342,20 @@ async fn test_query_parameters_with_metadata() -> Result<()> { ])) .unwrap(); - // df_with_params_replaced.schema() is not correct here - // https://github.com/apache/datafusion/issues/18102 - let batches = df_with_params_replaced.clone().collect().await.unwrap(); - let schema = batches[0].schema(); - + let schema = df_with_params_replaced.schema(); assert_eq!(schema.field(0).data_type(), &DataType::UInt32); assert_eq!(schema.field(0).metadata(), &metadata1); assert_eq!(schema.field(1).data_type(), &DataType::Utf8); assert_eq!(schema.field(1).metadata(), &metadata2); - assert_batches_eq!( - [ - "+----+-----+", - "| $1 | $2 |", - "+----+-----+", - "| 1 | two |", - "+----+-----+", - ], - &batches - ); + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-----+ + | $1 | $2 | + +----+-----+ + | 1 | two | + +----+-----+ + "); Ok(()) } @@ -421,3 +414,82 @@ async fn test_select_no_projection() -> Result<()> { "); Ok(()) } + +#[tokio::test] +async fn test_select_cast_date_literal_to_timestamp_overflow() -> Result<()> { + let ctx = SessionContext::new(); + let err = ctx + .sql("SELECT CAST(DATE '9999-12-31' AS TIMESTAMP)") + .await? + .collect() + .await + .unwrap_err(); + + assert_contains!( + err.to_string(), + "Cannot cast Date32 value 2932896 to Timestamp(ns): converted value exceeds the representable i64 range" + ); + Ok(()) +} + +// Regression test: a recursive CTE whose anchor aliases a computed column +// (`upper(val) AS val`) and whose recursive term leaves the same expression +// un-aliased must still produce batches whose schema field names come from +// the anchor term — especially when the outer query uses ORDER BY + LIMIT +// (the TopK path passes batch schemas through verbatim, so any drift in +// RecursiveQueryExec's emitted batches is visible downstream). +#[tokio::test] +async fn test_recursive_cte_batch_schema_stable_with_order_by_limit() -> Result<()> { + let ctx = SessionContext::new(); + ctx.sql( + "CREATE TABLE records (\ + id VARCHAR NOT NULL, \ + parent_id VARCHAR, \ + ts TIMESTAMP NOT NULL, \ + val VARCHAR\ + )", + ) + .await? 
+ .collect() + .await?; + ctx.sql( + "INSERT INTO records VALUES \ + ('a00', NULL, TIMESTAMP '2025-01-01 00:00:00', 'v_span'), \ + ('a01', 'a00', TIMESTAMP '2025-01-01 00:00:01', 'v_log_1'), \ + ('a02', 'a01', TIMESTAMP '2025-01-01 00:00:02', 'v_log_2'), \ + ('a03', 'a02', TIMESTAMP '2025-01-01 00:00:03', 'v_log_3'), \ + ('a04', 'a03', TIMESTAMP '2025-01-01 00:00:04', 'v_log_4'), \ + ('a05', 'a04', TIMESTAMP '2025-01-01 00:00:05', 'v_log_5')", + ) + .await? + .collect() + .await?; + + let results = ctx + .sql( + "WITH RECURSIVE descendants AS (\ + SELECT id, parent_id, ts, upper(val) AS val \ + FROM records WHERE id = 'a00' \ + UNION ALL \ + SELECT r.id, r.parent_id, r.ts, upper(r.val) \ + FROM records r INNER JOIN descendants d ON r.parent_id = d.id \ + ) \ + SELECT id, parent_id, ts, val FROM descendants ORDER BY ts ASC LIMIT 10", + ) + .await? + .collect() + .await?; + + let expected_names = ["id", "parent_id", "ts", "val"]; + assert!(!results.is_empty(), "expected at least one batch"); + for (i, batch) in results.iter().enumerate() { + let schema = batch.schema(); + let actual_names: Vec<&str> = + schema.fields().iter().map(|f| f.name().as_str()).collect(); + assert_eq!( + actual_names, expected_names, + "batch {i} schema field names leaked from recursive branch" + ); + } + Ok(()) +} diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index b87afd27ddea7..290aa737d2742 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::prelude::*; +use datafusion_common::assert_contains; use tempfile::TempDir; @@ -206,3 +207,19 @@ async fn ddl_can_not_be_planned_by_session_state() { "This feature is not implemented: Unsupported logical plan: DropTable" ); } + +#[tokio::test] +async fn invalid_wrapped_negation_fails_during_optimization() { + let ctx = SessionContext::new(); + let err = ctx + .sql("SELECT * FROM (SELECT 1) WHERE ((-'a') IS NULL)") + .await + .unwrap() + .into_optimized_plan() + .unwrap_err(); + + assert_contains!( + err.strip_backtrace(), + "Negation only supports numeric, interval and timestamp types" + ); +} diff --git a/datafusion/core/tests/sql/unparser.rs b/datafusion/core/tests/sql/unparser.rs new file mode 100644 index 0000000000000..d6ca872e198c3 --- /dev/null +++ b/datafusion/core/tests/sql/unparser.rs @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL Unparser Roundtrip Integration Tests +//! +//! This module tests the [`Unparser`] by running queries through a complete roundtrip: +//! the original SQL is parsed into a logical plan, unparsed back to SQL, then that +//! generated SQL is parsed and executed. The results are compared to verify semantic +//! equivalence. +//! 
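+//! In sketch form, each case does roughly the following (`original_sql` stands
+//! in for one benchmark query; `ctx` is a context with the benchmark tables
+//! registered — see [`collect_results`] below for the full logic):
+//!
+//! ```ignore
+//! let df = ctx.sql(original_sql).await?;
+//! let unparsed = Unparser::new(&DefaultDialect {}).plan_to_sql(df.logical_plan())?;
+//! let roundtripped = ctx.sql(&unparsed.to_string()).await?.collect().await?;
+//! // `roundtripped` is then compared against the original query's results
+//! ```
+//!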
+//! ## Test Strategy
+//!
+//! Uses real-world benchmark queries (TPC-H and Clickbench) to validate that:
+//! 1. The unparser produces syntactically valid SQL
+//! 2. The unparsed SQL is semantically equivalent (produces identical results)
+//!
+//! ## Query Suites
+//!
+//! - **TPC-H**: Standard decision-support benchmark with 22 complex analytical queries
+//! - **Clickbench**: Web analytics benchmark with 43 queries against a denormalized schema
+//!
+//! [`Unparser`]: datafusion_sql::unparser::Unparser
+
+use std::fs::ReadDir;
+use std::future::Future;
+
+use arrow::array::RecordBatch;
+use datafusion::common::Result;
+use datafusion::prelude::{ParquetReadOptions, SessionContext};
+use datafusion_common::Column;
+use datafusion_expr::Expr;
+use datafusion_sql::unparser::Unparser;
+use datafusion_sql::unparser::dialect::DefaultDialect;
+use itertools::Itertools;
+use recursive::{set_minimum_stack_size, set_stack_allocation_size};
+
+/// Paths to benchmark query files (supports running from repo root or different working directories).
+const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"];
+
+/// Reads all `.sql` files from a directory and converts them to test queries.
+///
+/// Skips files that:
+/// - Are not regular files
+/// - Don't have a `.sql` extension
+/// - Contain multiple SQL statements (indicated by `;\n`)
+///
+/// Multi-statement files are skipped because the unparser doesn't support
+/// DDL statements like `CREATE VIEW` that appear in multi-statement Clickbench queries.
+fn iterate_queries(dir: ReadDir) -> Vec<TestQuery> {
+    let mut queries = vec![];
+    for entry in dir.flatten() {
+        let Ok(file_type) = entry.file_type() else {
+            continue;
+        };
+        if !file_type.is_file() {
+            continue;
+        }
+        let path = entry.path();
+        let Some(ext) = path.extension() else {
+            continue;
+        };
+        if ext != "sql" {
+            continue;
+        }
+        let name = path.file_stem().unwrap().to_string_lossy().to_string();
+        if let Ok(mut contents) = std::fs::read_to_string(entry.path()) {
+            // If the query contains ;\n it has DDL statements like CREATE VIEW which the unparser doesn't support; skip it
+            contents = contents.trim().to_string();
+            if contents.contains(";\n") {
+                println!("Skipping query with multiple statements: {name}");
+                continue;
+            }
+            queries.push(TestQuery {
+                sql: contents,
+                name,
+            });
+        }
+    }
+    queries
+}
+
+/// A SQL query loaded from a benchmark file for roundtrip testing.
+///
+/// Each query is identified by its filename (without extension) and contains
+/// the full SQL text to be tested.
+struct TestQuery {
+    /// The SQL query text to test.
+    sql: String,
+    /// The query identifier (typically the filename without .sql extension).
+    name: String,
+}
+
+/// Collect SQL for Clickbench queries.
+fn clickbench_queries() -> Vec<TestQuery> {
+    let mut queries = vec![];
+    for path in BENCHMARK_PATHS {
+        let dir = format!("{path}queries/clickbench/queries/");
+        println!("Reading Clickbench queries from {dir}");
+        if let Ok(dir) = std::fs::read_dir(dir) {
+            let read = iterate_queries(dir);
+            println!("Found {} Clickbench queries", read.len());
+            queries.extend(read);
+        }
+    }
+    queries.sort_unstable_by_key(|q| {
+        q.name
+            .split('q')
+            .next_back()
+            .and_then(|num| num.parse::<u32>().ok())
+    });
+    queries
+}
+
+/// Collect SQL for TPC-H queries.
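+///
+/// Queries are sorted by file name so failures are reported in a stable order.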
+fn tpch_queries() -> Vec<TestQuery> {
+    let mut queries = vec![];
+    for path in BENCHMARK_PATHS {
+        let dir = format!("{path}queries/");
+        println!("Reading TPC-H queries from {dir}");
+        if let Ok(dir) = std::fs::read_dir(dir) {
+            let read = iterate_queries(dir);
+            queries.extend(read);
+        }
+    }
+    println!("Total TPC-H queries found: {}", queries.len());
+    queries.sort_unstable_by_key(|q| q.name.clone());
+    queries
+}
+
+/// Create a new SessionContext for testing that has all Clickbench tables registered.
+///
+/// Registers the raw Parquet as `hits_raw`, then creates a `hits` view that
+/// casts `EventDate` from UInt16 (day-offset) to DATE. This mirrors the
+/// approach used by the benchmark runner in `benchmarks/src/clickbench.rs`.
+async fn clickbench_test_context() -> Result<SessionContext> {
+    let ctx = SessionContext::new();
+    ctx.register_parquet(
+        "hits_raw",
+        "tests/data/clickbench_hits_10.parquet",
+        ParquetReadOptions::default(),
+    )
+    .await?;
+    ctx.sql(
+        r#"CREATE VIEW hits AS
+            SELECT * EXCEPT ("EventDate"),
+                   CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate"
+            FROM hits_raw"#,
+    )
+    .await?;
+    // Sanity check we found the table by querying its schema
+    let df = ctx.sql("SELECT * FROM hits LIMIT 1").await?;
+    assert!(
+        !df.schema().fields().is_empty(),
+        "Clickbench 'hits' table not registered correctly"
+    );
+    Ok(ctx)
+}
+
+/// Create a new SessionContext for testing that has all TPC-H tables registered.
+async fn tpch_test_context() -> Result<SessionContext> {
+    let ctx = SessionContext::new();
+    let data_dir = "tests/data/";
+    // All tables have the pattern "tpch_<table>_small.parquet"
+    for table in [
+        "customer", "lineitem", "nation", "orders", "part", "partsupp", "region",
+        "supplier",
+    ] {
+        let path = format!("{data_dir}tpch_{table}_small.parquet");
+        ctx.register_parquet(table, &path, ParquetReadOptions::default())
+            .await?;
+        // Sanity check we found the table by querying its schema; it should not be empty.
+        // Otherwise if the path is wrong the tests will all fail in confusing ways
+        let df = ctx.sql(&format!("SELECT * FROM {table} LIMIT 1")).await?;
+        assert!(
+            !df.schema().fields().is_empty(),
+            "TPC-H '{table}' table not registered correctly"
+        );
+    }
+    Ok(ctx)
+}
+
+/// Sorts record batches by all columns for deterministic comparison.
+///
+/// When comparing query results, we need a canonical ordering so that
+/// semantically equivalent results compare as equal. This function sorts
+/// by all columns in the schema to achieve that.
+async fn sort_batches(
+    ctx: &SessionContext,
+    batches: Vec<RecordBatch>,
+) -> Result<Vec<RecordBatch>> {
+    let mut df = ctx.read_batches(batches)?;
+    let schema = df.schema().as_arrow().clone();
+    let sort_exprs = schema
+        .fields()
+        .iter()
+        // Use Column directly, col() causes the column names to be normalized to lowercase
+        .map(|f| {
+            Expr::Column(Column::new_unqualified(f.name().to_string())).sort(true, false)
+        })
+        .collect_vec();
+    if !sort_exprs.is_empty() {
+        df = df.sort(sort_exprs)?;
+    }
+    df.collect().await
+}
+
+/// The outcome of running a single roundtrip test.
+///
+/// A successful test produces [`TestCaseResult::Success`].
+/// All other variants capture different failure modes with enough context to diagnose the issue.
+enum TestCaseResult {
+    /// The unparsed SQL produced identical results to the original.
+    Success,
+
+    /// Both queries executed but produced different results.
+    ///
+    /// This indicates a semantic bug in the unparser where the generated SQL
+    /// has different meaning than the original.
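+    /// (For example, an alias or a predicate that binds differently after
+    /// unparsing would surface as this variant.)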
+ ResultsMismatch { original: String, unparsed: String }, + + /// The unparser failed to convert the logical plan to SQL. + /// + /// This may indicate an unsupported SQL feature or a bug in the unparser. + UnparseError { original: String, error: String }, + + /// The original SQL failed to execute. + /// + /// This indicates a problem with the test setup (missing tables, + /// invalid test data) rather than an unparser issue. + ExecutionError { original: String, error: String }, + + /// The unparsed SQL failed to execute, even though the original succeeded. + /// + /// This indicates the unparser generated syntactically invalid SQL or SQL + /// that references non-existent columns/tables. + UnparsedExecutionError { + original: String, + unparsed: String, + error: String, + }, +} + +impl TestCaseResult { + /// Returns true if the test case represents a failure + /// (anything other than [`TestCaseResult::Success`]). + fn is_failure(&self) -> bool { + !matches!(self, TestCaseResult::Success) + } + + /// Formats a detailed error message for the test case into a string. + fn format_error(&self, name: &str) -> String { + match self { + TestCaseResult::Success => String::new(), + TestCaseResult::ResultsMismatch { original, unparsed } => { + format!( + "Results mismatch for {name}.\nOriginal SQL:\n{original}\n\nUnparsed SQL:\n{unparsed}" + ) + } + TestCaseResult::UnparseError { original, error } => { + format!("Unparse error for {name}: {error}\nOriginal SQL:\n{original}") + } + TestCaseResult::ExecutionError { original, error } => { + format!("Execution error for {name}: {error}\nOriginal SQL:\n{original}") + } + TestCaseResult::UnparsedExecutionError { + original, + unparsed, + error, + } => { + format!( + "Unparsed execution error for {name}: {error}\nOriginal SQL:\n{original}\n\nUnparsed SQL:\n{unparsed}" + ) + } + } + } +} + +/// Executes a roundtrip test for a single SQL query. +/// +/// This is the core test logic that: +/// 1. Parses the original SQL and creates a logical plan +/// 2. Unparses the logical plan back to SQL +/// 3. Executes both the original and unparsed queries +/// 4. Compares the results (sorting if the query has no ORDER BY) +/// +/// This always uses [`DefaultDialect`] for unparsing. +/// +/// # Arguments +/// +/// * `ctx` - Session context with tables registered +/// * `original` - The original SQL query to test +/// +/// # Returns +/// +/// A [`TestCaseResult`] indicating success or the specific failure mode. 
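+///
+/// Sketch of intended use (assuming a context with the relevant tables
+/// registered):
+///
+/// ```ignore
+/// let result = collect_results(&ctx, "SELECT 1").await;
+/// assert!(!result.is_failure());
+/// ```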
+async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult { + let unparser = Unparser::new(&DefaultDialect {}); + + // Parse and create logical plan from original SQL + let df = match ctx.sql(original).await { + Ok(df) => df, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Unparse the logical plan back to SQL + let unparsed = match unparser.plan_to_sql(df.logical_plan()) { + Ok(sql) => format!("{sql:#}"), + Err(e) => { + return TestCaseResult::UnparseError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Collect results from original query + let mut expected = match df.collect().await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Parse and execute the unparsed SQL + let actual_df = match ctx.sql(&unparsed).await { + Ok(df) => df, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: e.to_string(), + }; + } + }; + + // Collect results from unparsed query + let mut actual = match actual_df.collect().await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: e.to_string(), + }; + } + }; + + // Always sort for deterministic comparison — even "sorted" results can have + // tied rows in different order between original and unparsed SQL. + { + expected = match sort_batches(ctx, expected).await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: format!("Failed to sort expected results: {e}"), + }; + } + }; + actual = match sort_batches(ctx, actual).await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: format!("Failed to sort actual results: {e}"), + }; + } + }; + } + + if expected != actual { + TestCaseResult::ResultsMismatch { + original: original.to_string(), + unparsed, + } + } else { + TestCaseResult::Success + } +} + +/// Runs roundtrip tests for a collection of queries and reports results. +/// +/// Iterates through all queries, running each through [`collect_results`]. +/// Prints colored status (green checkmark for success, red X for failure) +/// and panics at the end if any tests failed, with detailed error messages. +/// +/// # Type Parameters +/// +/// * `F` - Factory function that creates fresh session contexts +/// * `Fut` - Future type returned by the context factory +/// +/// # Panics +/// +/// Panics if any query fails the roundtrip test, displaying all failures. 
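+///
+/// (See `test_tpch_unparser_roundtrip` at the bottom of this file for a
+/// concrete invocation.)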
+async fn run_roundtrip_tests<F, Fut>(
+    suite_name: &str,
+    queries: Vec<TestQuery>,
+    create_context: F,
+) where
+    F: Fn() -> Fut,
+    Fut: Future<Output = Result<SessionContext>>,
+{
+    let mut errors: Vec<String> = vec![];
+    for sql in queries {
+        let ctx = match create_context().await {
+            Ok(ctx) => ctx,
+            Err(e) => {
+                println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name);
+                errors.push(format!("Failed to create context for {}: {}", sql.name, e));
+                continue;
+            }
+        };
+        let result = collect_results(&ctx, &sql.sql).await;
+        if result.is_failure() {
+            println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name);
+            errors.push(result.format_error(&sql.name));
+        } else {
+            println!("\x1b[32m✓\x1b[0m {} query: {}", suite_name, sql.name);
+        }
+    }
+    if !errors.is_empty() {
+        panic!(
+            "{} {} test(s) failed:\n\n{}",
+            errors.len(),
+            suite_name,
+            errors.join("\n\n---\n\n")
+        );
+    }
+}
+
+#[tokio::test]
+async fn test_clickbench_unparser_roundtrip() {
+    run_roundtrip_tests("Clickbench", clickbench_queries(), clickbench_test_context)
+        .await;
+}
+
+#[tokio::test]
+async fn test_tpch_unparser_roundtrip() {
+    // Grow stacker segments earlier to avoid deep unparser recursion overflow in q20.
+    set_minimum_stack_size(512 * 1024);
+    set_stack_allocation_size(8 * 1024 * 1024);
+    run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await;
+}
diff --git a/datafusion/core/tests/tpc-ds/30.sql b/datafusion/core/tests/tpc-ds/30.sql
index 78f34b807e5b5..80624f49006a9 100644
--- a/datafusion/core/tests/tpc-ds/30.sql
+++ b/datafusion/core/tests/tpc-ds/30.sql
@@ -14,7 +14,7 @@ with customer_total_return as
  ,ca_state)
  select c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag
        ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address
-       ,c_last_review_date_sk,ctr_total_return
+       ,c_last_review_date,ctr_total_return
 from customer_total_return ctr1
     ,customer_address
     ,customer
@@ -26,7 +26,7 @@ with customer_total_return as
  and ctr1.ctr_customer_sk = c_customer_sk
 order by c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag
         ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address
-        ,c_last_review_date_sk,ctr_total_return
+        ,c_last_review_date,ctr_total_return
 limit 100;
diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs
index 252d76d0f9d92..3ad74962bc2c0 100644
--- a/datafusion/core/tests/tpcds_planning.rs
+++ b/datafusion/core/tests/tpcds_planning.rs
@@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
     for sql in &sql {
         let df = ctx.sql(sql).await?;
         let (state, plan) = df.into_parts();
-        let plan = state.optimize(&plan)?;
         if create_physical {
             let _ = state.create_physical_plan(&plan).await?;
+        } else {
+            // Run the logical optimizer even if we are not creating the physical plan
+            // to ensure it will properly succeed
+            let _ = state.optimize(&plan)?;
         }
     }
 
diff --git a/datafusion/core/tests/tracing/asserting_tracer.rs b/datafusion/core/tests/tracing/asserting_tracer.rs
index 292e066e5f121..700f9f3308466 100644
--- a/datafusion/core/tests/tracing/asserting_tracer.rs
+++ b/datafusion/core/tests/tracing/asserting_tracer.rs
@@ -21,7 +21,7 @@ use std::ops::Deref;
 use std::sync::{Arc, LazyLock};
 
 use datafusion_common::{HashMap, HashSet};
-use datafusion_common_runtime::{set_join_set_tracer, JoinSetTracer};
+use datafusion_common_runtime::{JoinSetTracer, set_join_set_tracer};
 
 use futures::future::BoxFuture;
 use tokio::sync::{Mutex, MutexGuard};
diff --git a/datafusion/core/tests/tracing/traceable_object_store.rs b/datafusion/core/tests/tracing/traceable_object_store.rs
index 60ef1cc5d6b6a..71a61dbf8772a 100644
--- a/datafusion/core/tests/tracing/traceable_object_store.rs
+++ b/datafusion/core/tests/tracing/traceable_object_store.rs
@@ -18,10 +18,11 @@
 //! Object store implementation used for testing
 
 use crate::tracing::asserting_tracer::assert_traceability;
+use futures::StreamExt;
 use futures::stream::BoxStream;
 use object_store::{
-    path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta,
-    ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult,
+    CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta,
+    ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path,
 };
 use std::fmt::{Debug, Display, Formatter};
 use std::sync::Arc;
@@ -83,14 +84,17 @@ impl ObjectStore for TraceableObjectStore {
         self.inner.get_opts(location, options).await
     }
 
-    async fn head(&self, location: &Path) -> object_store::Result<ObjectMeta> {
-        assert_traceability().await;
-        self.inner.head(location).await
-    }
-
-    async fn delete(&self, location: &Path) -> object_store::Result<()> {
-        assert_traceability().await;
-        self.inner.delete(location).await
+    fn delete_stream(
+        &self,
+        locations: BoxStream<'static, object_store::Result<Path>>,
+    ) -> BoxStream<'static, object_store::Result<Path>> {
+        self.inner
+            .delete_stream(locations)
+            .then(|res| async {
+                futures::executor::block_on(assert_traceability());
+                res
+            })
+            .boxed()
     }
 
     fn list(
@@ -109,17 +113,13 @@ impl ObjectStore for TraceableObjectStore {
         self.inner.list_with_delimiter(prefix).await
     }
 
-    async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> {
-        assert_traceability().await;
-        self.inner.copy(from, to).await
-    }
-
-    async fn copy_if_not_exists(
+    async fn copy_opts(
         &self,
         from: &Path,
         to: &Path,
+        options: CopyOptions,
     ) -> object_store::Result<()> {
         assert_traceability().await;
-        self.inner.copy_if_not_exists(from, to).await
+        self.inner.copy_opts(from, to, options).await
     }
 }
diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs
index 07d289cab06c2..c5e5af731359f 100644
--- a/datafusion/core/tests/user_defined/expr_planner.rs
+++ b/datafusion/core/tests/user_defined/expr_planner.rs
@@ -26,9 +26,9 @@ use datafusion::logical_expr::Operator;
 use datafusion::prelude::*;
 use datafusion::sql::sqlparser::ast::BinaryOperator;
 use datafusion_common::ScalarValue;
+use datafusion_expr::BinaryExpr;
 use datafusion_expr::expr::Alias;
 use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
-use datafusion_expr::BinaryExpr;
 
 #[derive(Debug)]
 struct MyCustomPlanner;
@@ -77,25 +77,25 @@ async fn plan_and_collect(sql: &str) -> Result<Vec<RecordBatch>> {
 #[tokio::test]
 async fn test_custom_operators_arrow() {
     let actual = plan_and_collect("select 'foo'->'bar';").await.unwrap();
-    insta::assert_snapshot!(batches_to_string(&actual), @r###"
+    insta::assert_snapshot!(batches_to_string(&actual), @r#"
     +----------------------------+
     | Utf8("foo") || Utf8("bar") |
     +----------------------------+
     | foobar                     |
    +----------------------------+
-    "###);
+    "#);
 }
 
 #[tokio::test]
 async fn test_custom_operators_long_arrow() {
     let actual = plan_and_collect("select 1->>2;").await.unwrap();
-    insta::assert_snapshot!(batches_to_string(&actual), @r###"
+    insta::assert_snapshot!(batches_to_string(&actual), @r"
     +---------------------+
     | Int64(1) + Int64(2) |
     +---------------------+
     | 3                   |
     +---------------------+
-    "###);
+    ");
 }
 
 #[tokio::test]
@@ -103,13 +103,13 @@ async fn test_question_select() {
     let actual = plan_and_collect("select a ? 2 from (select 1 as a);")
         .await
         .unwrap();
-    insta::assert_snapshot!(batches_to_string(&actual), @r###"
+    insta::assert_snapshot!(batches_to_string(&actual), @r"
    +--------------+
    | a ? Int64(2) |
    +--------------+
    | true         |
    +--------------+
-    "###);
+    ");
 }
 
 #[tokio::test]
@@ -117,11 +117,11 @@ async fn test_question_filter() {
     let actual = plan_and_collect("select a from (select 1 as a) where a ? 2;")
         .await
         .unwrap();
-    insta::assert_snapshot!(batches_to_string(&actual), @r###"
+    insta::assert_snapshot!(batches_to_string(&actual), @r"
    +---+
    | a |
    +---+
    | 1 |
    +---+
-    "###);
+    ");
 }
diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs
index e0a3e98604ae4..326c767d97610 100644
--- a/datafusion/core/tests/user_defined/insert_operation.rs
+++ b/datafusion/core/tests/user_defined/insert_operation.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, str::FromStr, sync::Arc};
+use std::{str::FromStr, sync::Arc};
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
@@ -25,12 +25,13 @@ use datafusion::{
 };
 use datafusion_catalog::{Session, TableProvider};
 use datafusion_common::config::Dialect;
-use datafusion_expr::{dml::InsertOp, Expr, TableType};
+use datafusion_common::tree_node::TreeNodeRecursion;
+use datafusion_expr::{Expr, TableType, dml::InsertOp};
 use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
 use datafusion_physical_plan::execution_plan::SchedulingType;
 use datafusion_physical_plan::{
-    execution_plan::{Boundedness, EmissionType},
     DisplayAs, ExecutionPlan, PlanProperties,
+    execution_plan::{Boundedness, EmissionType},
 };
 
 #[tokio::test]
@@ -57,7 +58,7 @@ async fn insert_operation_is_passed_correctly_to_table_provider() {
 async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) {
     let df = ctx.sql(sql).await.unwrap();
     let plan = df.create_physical_plan().await.unwrap();
-    let exec = plan.as_any().downcast_ref::<TestInsertExec>().unwrap();
+    let exec = plan.downcast_ref::<TestInsertExec>().unwrap();
     assert_eq!(exec.op, insert_op);
 }
 
@@ -87,10 +88,6 @@ impl TestInsertTableProvider {
 
 #[async_trait]
 impl TableProvider for TestInsertTableProvider {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }
@@ -122,20 +119,22 @@ impl TableProvider for TestInsertTableProvider {
 #[derive(Debug)]
 struct TestInsertExec {
     op: InsertOp,
-    plan_properties: PlanProperties,
+    plan_properties: Arc<PlanProperties>,
 }
 
 impl TestInsertExec {
     fn new(op: InsertOp) -> Self {
         Self {
             op,
-            plan_properties: PlanProperties::new(
-                EquivalenceProperties::new(make_count_schema()),
-                Partitioning::UnknownPartitioning(1),
-                EmissionType::Incremental,
-                Boundedness::Bounded,
-            )
-            .with_scheduling_type(SchedulingType::Cooperative),
+            plan_properties: Arc::new(
+                PlanProperties::new(
+                    EquivalenceProperties::new(make_count_schema()),
+                    Partitioning::UnknownPartitioning(1),
+                    EmissionType::Incremental,
+                    Boundedness::Bounded,
+                )
+                .with_scheduling_type(SchedulingType::Cooperative),
+            ),
         }
     }
 }
@@ -155,11 +154,7 @@ impl ExecutionPlan for TestInsertExec {
         "TestInsertExec"
     }
 
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn properties(&self) -> &PlanProperties {
+    fn properties(&self) -> &Arc<PlanProperties> {
         &self.plan_properties
     }
 
@@ -182,6 +177,22 @@ ) -> Result<SendableRecordBatchStream> {
unimplemented!("TestInsertExec is a stub for testing.")
     }
+
+    fn apply_expressions(
+        &self,
+        f: &mut dyn FnMut(
+            &dyn datafusion_physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        // Visit expressions in the output ordering from equivalence properties
+        let mut tnr = TreeNodeRecursion::Continue;
+        if let Some(ordering) = self.plan_properties.output_ordering() {
+            for sort_expr in ordering {
+                tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?;
+            }
+        }
+        Ok(tnr)
+    }
 }
 
 fn make_count_schema() -> SchemaRef {
diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs
index 5d84cdb692830..bc9949f5d681c 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+/// Tests for user defined Async Scalar functions
+mod user_defined_async_scalar_functions;
+
 /// Tests for user defined Scalar functions
 mod user_defined_scalar_functions;
 
@@ -33,5 +36,8 @@ mod user_defined_table_functions;
 /// Tests for Expression Planner
 mod expr_planner;
 
+/// Tests for Relation Planner extensions
+mod relation_planner;
+
 /// Tests for insert operations
 mod insert_operation;
diff --git a/datafusion/core/tests/user_defined/relation_planner.rs b/datafusion/core/tests/user_defined/relation_planner.rs
new file mode 100644
index 0000000000000..54af53ad858d4
--- /dev/null
+++ b/datafusion/core/tests/user_defined/relation_planner.rs
@@ -0,0 +1,531 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for the RelationPlanner extension point
+
+use std::sync::Arc;
+
+use arrow::array::{Int64Array, RecordBatch, StringArray};
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::catalog::memory::MemTable;
+use datafusion::common::test_util::batches_to_string;
+use datafusion::prelude::*;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::Expr;
+use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
+use datafusion_expr::planner::{
+    PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning,
+};
+use datafusion_sql::sqlparser::ast::TableFactor;
+use insta::assert_snapshot;
+
+// ============================================================================
+// Test Planners - Example Implementations
+// ============================================================================
+
+// The planners in this section are deliberately minimal, static examples used
+// only for tests. In real applications a `RelationPlanner` would typically
+// construct richer logical plans tailored to external systems or custom
+// semantics rather than hard-coded in-memory tables.
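+//
+// The contract each example follows: return `RelationPlanning::Planned` to
+// claim a relation, or hand it back via `RelationPlanning::Original` so that
+// another planner (or the built-in planning path) can handle it.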
+// +// For more realistic examples, see `datafusion-examples/examples/relation_planner/`: +// - `table_sample.rs`: Full TABLESAMPLE implementation (parsing → execution) +// - `pivot_unpivot.rs`: PIVOT/UNPIVOT via SQL rewriting +// - `match_recognize.rs`: MATCH_RECOGNIZE logical planning + +/// Helper to build simple static values-backed virtual tables used by the +/// example planners below. +fn plan_static_values_table( + relation: TableFactor, + table_name: &str, + column_name: &str, + values: Vec<ScalarValue>, +) -> Result<RelationPlanning> { + match relation { + TableFactor::Table { name, alias, .. } + if name.to_string().eq_ignore_ascii_case(table_name) => + { + let rows = values + .into_iter() + .map(|v| vec![Expr::Literal(v, None)]) + .collect::<Vec<_>>(); + + let plan = LogicalPlanBuilder::values(rows)? + .project(vec![col("column1").alias(column_name)])? + .build()?; + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + other => Ok(RelationPlanning::Original(Box::new(other))), + } +} + +/// Example planner that provides a virtual `numbers` table with values +/// 1, 2, 3. +#[derive(Debug)] +struct NumbersPlanner; + +impl RelationPlanner for NumbersPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + plan_static_values_table( + relation, + "numbers", + "number", + vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)), + ScalarValue::Int64(Some(3)), + ], + ) + } +} + +/// Example planner that provides a virtual `colors` table with three string +/// values: `red`, `green`, `blue`. +#[derive(Debug)] +struct ColorsPlanner; + +impl RelationPlanner for ColorsPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + plan_static_values_table( + relation, + "colors", + "color", + vec![ + ScalarValue::Utf8(Some("red".into())), + ScalarValue::Utf8(Some("green".into())), + ScalarValue::Utf8(Some("blue".into())), + ], + ) + } +} + +/// Alternative implementation of `numbers` (returns 100, 200) used to +/// demonstrate planner precedence (last registered planner wins). +#[derive(Debug)] +struct AlternativeNumbersPlanner; + +impl RelationPlanner for AlternativeNumbersPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + plan_static_values_table( + relation, + "numbers", + "number", + vec![ScalarValue::Int64(Some(100)), ScalarValue::Int64(Some(200))], + ) + } +} + +/// Example planner that intercepts nested joins and samples both sides (limit 2) +/// before joining, demonstrating recursive planning with `context.plan()`. +#[derive(Debug)] +struct SamplingJoinPlanner; + +impl RelationPlanner for SamplingJoinPlanner { + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + match relation { + TableFactor::NestedJoin { + table_with_joins, + alias, + ..
+ } if table_with_joins.joins.len() == 1 => { + // Use context.plan() to recursively plan both sides + // This ensures other planners (like NumbersPlanner) can handle them + let left = context.plan(table_with_joins.relation.clone())?; + let right = context.plan(table_with_joins.joins[0].relation.clone())?; + + // Sample each table to 2 rows + let left_sampled = + LogicalPlanBuilder::from(left).limit(0, Some(2))?.build()?; + + let right_sampled = + LogicalPlanBuilder::from(right).limit(0, Some(2))?.build()?; + + // Cross join: 2 rows × 2 rows = 4 rows (instead of 3×3=9 without sampling) + let plan = LogicalPlanBuilder::from(left_sampled) + .cross_join(right_sampled)? + .build()?; + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + other => Ok(RelationPlanning::Original(Box::new(other))), + } + } +} + +/// Example planner that never handles any relation and always delegates by +/// returning `RelationPlanning::Original`. +#[derive(Debug)] +struct PassThroughPlanner; + +impl RelationPlanner for PassThroughPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + // Never handles anything - always delegates + Ok(RelationPlanning::Original(Box::new(relation))) + } +} + +/// Example planner that shows how planners can block specific constructs and +/// surface custom error messages by rejecting `UNNEST` relations (here framed +/// as a mock premium feature check). +#[derive(Debug)] +struct PremiumFeaturePlanner; + +impl RelationPlanner for PremiumFeaturePlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + match relation { + TableFactor::UNNEST { .. } => Err(datafusion_common::DataFusionError::Plan( + "UNNEST is a premium feature! Please upgrade to DataFusion Pro™ \ + to unlock advanced array operations." + .to_string(), + )), + other => Ok(RelationPlanning::Original(Box::new(other))), + } + } +} + +// ============================================================================ +// Test Helpers - SQL Execution +// ============================================================================ + +/// Execute SQL and return results with better error messages. +async fn execute_sql(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> { + let df = ctx.sql(sql).await?; + df.collect().await +} + +/// Execute SQL and convert to string format for snapshot comparison. +async fn execute_sql_to_string(ctx: &SessionContext, sql: &str) -> String { + let batches = execute_sql(ctx, sql) + .await + .expect("SQL execution should succeed"); + batches_to_string(&batches) +} + +// ============================================================================ +// Test Helpers - Context Builders +// ============================================================================ + +/// Create a SessionContext with a catalog table containing Int64 and Utf8 columns. +/// +/// Creates a table with the specified name and sample data for fallback/integration tests.
+fn create_context_with_catalog_table( + table_name: &str, + id_values: Vec<i64>, + name_values: Vec<&str>, +) -> SessionContext { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(id_values)), + Arc::new(StringArray::from(name_values)), + ], + ) + .unwrap(); + + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table(table_name, Arc::new(table)).unwrap(); + + ctx +} + +/// Create a SessionContext with a simple single-column Int64 table. +/// +/// Useful for basic tests that need a real catalog table. +fn create_context_with_simple_table( + table_name: &str, + values: Vec<i64>, +) -> SessionContext { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))]) + .unwrap(); + + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table(table_name, Arc::new(table)).unwrap(); + + ctx +} + +// ============================================================================ +// TESTS: Ordered from Basic to Complex +// ============================================================================ + +/// Comprehensive test suite for RelationPlanner extension point. +/// Tests are ordered from simplest smoke test to most complex scenarios. +#[cfg(test)] +mod tests { + use super::*; + + /// Small extension trait to make test setup read fluently. + trait TestSessionExt { + fn with_planner<P: RelationPlanner + 'static>(self, planner: P) -> Self; + } + + impl TestSessionExt for SessionContext { + fn with_planner<P: RelationPlanner + 'static>(self, planner: P) -> Self { + self.register_relation_planner(Arc::new(planner)).unwrap(); + self + } + } + + /// Session context with only the `NumbersPlanner` registered. + fn ctx_with_numbers() -> SessionContext { + SessionContext::new().with_planner(NumbersPlanner) + } + + /// Session context with virtual tables (`numbers`, `colors`) and the + /// `SamplingJoinPlanner` registered for nested joins. + fn ctx_with_virtual_tables_and_sampling() -> SessionContext { + SessionContext::new() + .with_planner(NumbersPlanner) + .with_planner(ColorsPlanner) + .with_planner(SamplingJoinPlanner) + } + + // Basic smoke test: virtual table can be queried like a regular table. + #[tokio::test] + async fn virtual_table_basic_select() { + let ctx = ctx_with_numbers(); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + + assert_snapshot!(result, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + } + + // Virtual table supports standard SQL operations (projection, filter, aggregation).
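The tests that follow all rely on the same dispatch semantics, so it is worth spelling them out once. The sketch below is a conceptual re-implementation of the behavior the assertions verify (LIFO precedence, pass-through delegation, catalog fallback), not DataFusion's actual internals:

    // Conceptual dispatch, assuming `planners` is ordered newest-first
    fn dispatch(
        planners: &[Arc<dyn RelationPlanner>],
        mut relation: TableFactor,
        ctx: &mut dyn RelationPlannerContext,
    ) -> Result<Option<Box<PlannedRelation>>> {
        for planner in planners {
            match planner.plan_relation(relation, ctx)? {
                // the first planner to claim the relation wins
                RelationPlanning::Planned(planned) => return Ok(Some(planned)),
                // a declined relation is handed to the next planner unchanged
                RelationPlanning::Original(original) => relation = *original,
            }
        }
        // no planner claimed it: the caller falls back to catalog resolution
        Ok(None)
    }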
+ #[tokio::test] + async fn virtual_table_filters_and_aggregation() { + let ctx = ctx_with_numbers(); + + let filtered = execute_sql_to_string( + &ctx, + "SELECT number * 10 AS scaled FROM numbers WHERE number > 1", + ) + .await; + + assert_snapshot!(filtered, @r" + +--------+ + | scaled | + +--------+ + | 20 | + | 30 | + +--------+ + "); + + let aggregated = execute_sql_to_string( + &ctx, + "SELECT COUNT(*) as count, SUM(number) as total, AVG(number) as average \ + FROM numbers", + ) + .await; + + assert_snapshot!(aggregated, @r" + +-------+-------+---------+ + | count | total | average | + +-------+-------+---------+ + | 3 | 6 | 2.0 | + +-------+-------+---------+ + "); + } + + // Multiple planners can coexist and each handles its own virtual table. + #[tokio::test] + async fn multiple_planners_virtual_tables() { + let ctx = SessionContext::new() + .with_planner(NumbersPlanner) + .with_planner(ColorsPlanner); + + let result1 = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + assert_snapshot!(result1, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + + let result2 = execute_sql_to_string(&ctx, "SELECT * FROM colors").await; + assert_snapshot!(result2, @r" + +-------+ + | color | + +-------+ + | red | + | green | + | blue | + +-------+ + "); + } + + // Last registered planner for the same table name takes precedence (LIFO). + #[tokio::test] + async fn lifo_precedence_last_planner_wins() { + let ctx = SessionContext::new() + .with_planner(AlternativeNumbersPlanner) + .with_planner(NumbersPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + + // NumbersPlanner registered last, should win (returns 1,2,3 not 100,200) + assert_snapshot!(result, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + } + + // Pass-through planner delegates to the catalog without changing behavior. + #[tokio::test] + async fn delegation_pass_through_to_catalog() { + let ctx = create_context_with_simple_table("real_table", vec![42]) + .with_planner(PassThroughPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM real_table").await; + + assert_snapshot!(result, @r" + +-------+ + | value | + +-------+ + | 42 | + +-------+ + "); + } + + // Catalog is used when no planner claims the relation. + #[tokio::test] + async fn catalog_fallback_when_no_planner() { + let ctx = + create_context_with_catalog_table("users", vec![1, 2], vec!["Alice", "Bob"]) + .with_planner(NumbersPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM users ORDER BY id").await; + + assert_snapshot!(result, @r" + +----+-------+ + | id | name | + +----+-------+ + | 1 | Alice | + | 2 | Bob | + +----+-------+ + "); + } + + // Planners can block specific constructs and surface custom error messages.
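Blocking a construct, as the next test does, requires nothing more than returning an `Err` from `plan_relation`; the error then surfaces from `ctx.sql(...)` at planning time. A more compact form of `PremiumFeaturePlanner`'s match arm could use the `plan_err!` shorthand from `datafusion_common` (a sketch, not part of the test file):

    use datafusion_common::plan_err;

    match relation {
        // blocked construct: surface a custom planning error
        TableFactor::UNNEST { .. } => plan_err!("UNNEST is a premium feature!"),
        // everything else is delegated untouched
        other => Ok(RelationPlanning::Original(Box::new(other))),
    }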
+ #[tokio::test] + async fn error_handling_premium_feature_blocking() { + // Verify UNNEST works without planner + let ctx_without_planner = SessionContext::new(); + let result = + execute_sql(&ctx_without_planner, "SELECT * FROM UNNEST(ARRAY[1, 2, 3])") + .await + .expect("UNNEST should work by default"); + assert_eq!(result.len(), 1); + + // Same query with blocking planner registered + let ctx = SessionContext::new().with_planner(PremiumFeaturePlanner); + + // Verify UNNEST is now rejected + let error = execute_sql(&ctx, "SELECT * FROM UNNEST(ARRAY[1, 2, 3])") + .await + .expect_err("UNNEST should be rejected"); + + let error_msg = error.to_string(); + assert!( + error_msg.contains("premium feature") && error_msg.contains("DataFusion Pro"), + "Expected custom rejection message, got: {error_msg}" + ); + } + + // SamplingJoinPlanner recursively calls `context.plan()` on both sides of a + // nested join before sampling, exercising recursive relation planning. + #[tokio::test] + async fn recursive_planning_sampling_join() { + let ctx = ctx_with_virtual_tables_and_sampling(); + + let result = + execute_sql_to_string(&ctx, "SELECT * FROM (numbers JOIN colors ON true)") + .await; + + // SamplingJoinPlanner limits each side to 2 rows: 2×2=4 (not 3×3=9) + assert_snapshot!(result, @r" + +--------+-------+ + | number | color | + +--------+-------+ + | 1 | red | + | 1 | green | + | 2 | red | + | 2 | green | + +--------+-------+ + "); + } +} diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 62e8ab18b9be0..7d22c5df70dfc 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,18 +18,17 @@ //! This module contains end to end demonstrations of creating //! 
user defined aggregate functions -use std::any::Any; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ - atomic::{AtomicBool, Ordering}, Arc, + atomic::{AtomicBool, Ordering}, }; use arrow::array::{ - record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, - StringArray, StructArray, UInt64Array, + Array, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, UInt64Array, + record_batch, types::UInt64Type, }; use arrow::datatypes::{Fields, Schema}; use arrow_schema::FieldRef; @@ -56,8 +55,8 @@ use datafusion_common::{cast::as_primitive_array, exec_err}; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, - GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, + AggregateUDFImpl, Expr, GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, + WindowFunctionDefinition, col, create_udaf, function::AccumulatorArgs, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -69,7 +68,7 @@ async fn test_setup() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +-------+----------------------------+ | value | time | +-------+----------------------------+ @@ -79,7 +78,7 @@ async fn test_setup() { | 5.0 | 1970-01-01T00:00:00.000005 | | 5.0 | 1970-01-01T00:00:00.000005 | +-------+----------------------------+ - "###); + "); } /// Basic user defined aggregate @@ -91,13 +90,13 @@ async fn test_udaf() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum(t.time) | +----------------------------+ | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); // normal aggregates call update_batch assert!(test_state.update_batch()); @@ -112,7 +111,7 @@ async fn test_udaf_as_window() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum | +----------------------------+ @@ -122,7 +121,7 @@ async fn test_udaf_as_window() { | 1970-01-01T00:00:00.000019 | | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); // aggregate over the entire window function call update_batch assert!(test_state.update_batch()); @@ -137,7 +136,7 @@ async fn test_udaf_as_window_with_frame() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum | +----------------------------+ @@ -147,7 +146,7 @@ async fn test_udaf_as_window_with_frame() { | 1970-01-01T00:00:00.000014 | | 1970-01-01T00:00:00.000010 | +----------------------------+ - "###); + "); // user defined aggregates with window frame should be calling retract batch assert!(test_state.update_batch()); @@ -164,7 +163,10 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); - assert_contains!(err.to_string(), "This feature is 
not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); + assert_contains!( + err.to_string(), + "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING" + ); } /// Basic query for with a udaf returning a structure @@ -175,13 +177,13 @@ async fn test_udaf_returning_struct() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------------+ | first(t.value,t.time) | +------------------------------------------------+ | {value: 2.0, time: 1970-01-01T00:00:00.000002} | +------------------------------------------------+ - "###); + "); } /// Demonstrate extracting the fields from a structure using a subquery @@ -192,13 +194,13 @@ async fn test_udaf_returning_struct_subquery() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +-----------------+----------------------------+ | sq.first[value] | sq.first[time] | +-----------------+----------------------------+ | 2.0 | 1970-01-01T00:00:00.000002 | +-----------------+----------------------------+ - "###); + "); } #[tokio::test] @@ -212,13 +214,13 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r#" +---------------------------------------+ | sum(arrow_cast(t.time,Utf8("Int64"))) | +---------------------------------------+ | 19000 | +---------------------------------------+ - "###); + "#); // Register `TimeSum` with name `sum`. 
This will shadow the builtin one TimeSum::register(&mut ctx, test_state.clone(), "sum"); @@ -226,13 +228,13 @@ async fn test_udaf_shadows_builtin_fn() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | sum(t.time) | +----------------------------+ | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); } async fn execute(ctx: &SessionContext, sql: &str) -> Result> { @@ -272,13 +274,13 @@ async fn simple_udaf() -> Result<()> { let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-------------+ | my_avg(t.a) | +-------------+ | 3.0 | +-------------+ - "###); + "); Ok(()) } @@ -329,9 +331,10 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { // doesn't work as it was registered as non lowercase let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_avg\'")); + assert!( + err.to_string() + .contains("Error during planning: Invalid function \'my_avg\'") + ); // Can call it if you put quotes let result = ctx @@ -340,13 +343,13 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-------------+ | MY_AVG(t.i) | +-------------+ | 1.0 | +-------------+ - "###); + "); Ok(()) } @@ -372,13 +375,13 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +------------+ | dummy(t.i) | +------------+ | 1.0 | +------------+ - "###); + "); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; @@ -449,13 +452,13 @@ async fn test_parameterized_aggregate_udf() -> Result<()> { let actual = DataFrame::new(ctx.state(), plan).collect().await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------+---+---+ | text | a | b | +------+---+---+ | foo | 1 | 2 | +------+---+---+ - "###); + "); ctx.deregister_table("t")?; Ok(()) @@ -569,6 +572,7 @@ impl TimeSum { Self { sum: 0, test_state } } + #[expect(clippy::needless_pass_by_value)] fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); let input_type = vec![timestamp_type.clone()]; @@ -760,11 +764,11 @@ impl Accumulator for FirstSelector { // Update the actual values for (value, time) in v.iter().zip(t.iter()) { - if let (Some(time), Some(value)) = (time, value) { - if time < self.time { - self.value = value; - self.time = time; - } + if let (Some(time), Some(value)) = (time, value) + && time < self.time + { + self.value = value; + self.time = time; } } @@ -788,10 +792,6 @@ struct TestGroupsAccumulator { } impl AggregateUDFImpl for TestGroupsAccumulator { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "geo_mean" } @@ -931,10 +931,6 @@ impl MetadataBasedAggregateUdf { } impl AggregateUDFImpl for MetadataBasedAggregateUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn 
name(&self) -> &str { &self.name } diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs new file mode 100644 index 0000000000000..58a5cb803982b --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::prelude::*; +use datafusion_common::test_util::format_batches; +use datafusion_common::{Result, assert_batches_eq}; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +fn register_table_and_udf() -> Result<SessionContext> { + let num_rows = 3; + let batch_size = 2; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("prompt", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from((0..num_rows).collect::<Vec<_>>())), + Arc::new(StringArray::from( + (0..num_rows) + .map(|i| format!("prompt{i}")) + .collect::<Vec<_>>(), + )), + ], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("test_table", batch)?; + + ctx.register_udf( + AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(batch_size))) + .into_scalar_udf(), + ); + + Ok(ctx) +} + +// This test checks the case where batch_size doesn't evenly divide +// the number of rows.
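With `num_rows = 3` and `batch_size = 2`, the expectation is that `ideal_batch_size` causes the input to be re-chunked so the async UDF sees one full chunk of two rows followed by a trailing partial chunk of one row, with the results stitched back together in order. A quick sanity check of that arithmetic (assuming the re-batching is a plain split at `ideal_batch_size` boundaries):

    let (num_rows, batch_size) = (3usize, 2usize);
    // ceil(3 / 2) = 2 async invocations: rows [0, 1] first, then row [2]
    let invocations = num_rows.div_ceil(batch_size);
    assert_eq!(invocations, 2);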
+#[tokio::test] +async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { + let ctx = register_table_and_udf()?; + + let df = ctx + .sql("SELECT id, test_async_udf(prompt) as result FROM test_table") + .await?; + + let result = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+---------+", + "| id | result |", + "+----+---------+", + "| 0 | prompt0 |", + "| 1 | prompt1 |", + "| 2 | prompt2 |", + "+----+---------+" + ], + &result + ); + + Ok(()) +} + +// This test checks if metrics are printed for `AsyncFuncExec` +#[tokio::test] +async fn test_async_udf_metrics() -> Result<()> { + let ctx = register_table_and_udf()?; + + let df = ctx + .sql( + "EXPLAIN ANALYZE SELECT id, test_async_udf(prompt) as result FROM test_table", + ) + .await?; + + let result = df.collect().await?; + + let explain_analyze_str = format_batches(&result)?.to_string(); + let async_func_exec_without_metrics = + explain_analyze_str.split("\n").any(|metric_line| { + metric_line.contains("AsyncFuncExec") + && !metric_line.contains("output_rows=3") + }); + + assert!(!async_func_exec_without_metrics); + + Ok(()) +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct TestAsyncUDFImpl { + batch_size: usize, + signature: Signature, +} + +impl TestAsyncUDFImpl { + fn new(batch_size: usize) -> Self { + Self { + batch_size, + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for TestAsyncUDFImpl { + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { + panic!("Call invoke_async_with_args instead") + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for TestAsyncUDFImpl { + fn ideal_batch_size(&self) -> Option<usize> { + Some(self.batch_size) + } + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result<ColumnarValue> { + let arg1 = &args.args[0]; + let results = call_external_service(arg1.clone()).await?; + Ok(results) + } +} + +/// Simulates calling an async external service +async fn call_external_service(arg1: ColumnarValue) -> Result<ColumnarValue> { + Ok(arg1) +} diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index ffe0ba021edb3..505468a19cd37 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -60,7 +60,7 @@ use std::fmt::Debug; use std::hash::Hash; use std::task::{Context, Poll}; -use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; +use std::{collections::BTreeMap, fmt, sync::Arc}; use arrow::array::{Array, ArrayRef, StringViewArray}; use arrow::{ @@ -70,7 +70,7 @@ use arrow::{ use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::as_int64_array, - common::{arrow_datafusion_err, internal_err, DFSchemaRef}, + common::{DFSchemaRef, arrow_datafusion_err}, error::{DataFusionError, Result}, execution::{ context::{QueryPlanner, SessionState, TaskContext}, @@ -84,17 +84,19 @@ use datafusion::{ physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, }, physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner},
prelude::{SessionConfig, SessionContext}, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::{ScalarValue, assert_eq_or_internal_err, assert_or_internal_err}; use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr}; -use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; +use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; @@ -161,7 +163,7 @@ async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +-------------+---------+ | customer_id | revenue | +-------------+---------+ @@ -169,7 +171,7 @@ async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result | jorge | 200 | | andy | 150 | +-------------+---------+ - "###); + "); }); } @@ -188,13 +190,13 @@ async fn run_and_compare_query_with_analyzer_rule( insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +------------+--------------------------+ | UInt64(42) | arrow_typeof(UInt64(42)) | +------------+--------------------------+ | 42 | UInt64 | +------------+--------------------------+ - "###); + "); }); Ok(()) @@ -212,7 +214,7 @@ async fn run_and_compare_query_with_auto_schemas( insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +----------+----------+ | column_1 | column_2 | +----------+----------+ @@ -220,7 +222,7 @@ async fn run_and_compare_query_with_auto_schemas( | jorge | 200 | | andy | 150 | +----------+----------+ - "###); + "); }); Ok(()) @@ -433,21 +435,21 @@ impl OptimizerRule for OptimizerMakeExtensionNodeInvalid { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - if let LogicalPlan::Extension(Extension { node }) = &plan { - if let Some(prev) = node.as_any().downcast_ref::() { - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: prev.k, - input: prev.input.clone(), - expr: prev.expr.clone(), - // In a real use case, this rewriter could have change the number of inputs, etc - invariant_mock: Some(InvariantMock { - should_fail_invariant: true, - kind: InvariantLevel::Always, - }), + if let LogicalPlan::Extension(Extension { node }) = &plan + && let Some(prev) = node.as_any().downcast_ref::() + { + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: prev.k, + input: prev.input.clone(), + expr: prev.expr.clone(), + // In a real use case, this rewriter could have change the number of inputs, etc + invariant_mock: Some(InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, }), - }))); - } + }), + }))); }; Ok(Transformed::no(plan)) @@ -515,23 +517,18 @@ impl OptimizerRule for TopKOptimizerRule { return Ok(Transformed::no(plan)); }; - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = limit.input.as_ref() + if let LogicalPlan::Sort(Sort { expr, input, .. 
}) = limit.input.as_ref() + && expr.len() == 1 { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - invariant_mock: self.invariant_mock.clone(), - }), - }))); - } + // we found a sort with a single sort expr, replace with a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + invariant_mock: self.invariant_mock.clone(), + }), + }))); } Ok(Transformed::no(plan)) @@ -585,9 +582,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { kind, }) = self.invariant_mock.clone() { - if should_fail_invariant && check == kind { - return internal_err!("node fails check, such as improper inputs"); - } + assert_or_internal_err!( + !(should_fail_invariant && check == kind), + "node fails check, such as improper inputs" + ); } Ok(()) } @@ -657,13 +655,17 @@ struct TopKExec { input: Arc<dyn ExecutionPlan>, /// The maximum number of values k: usize, - cache: PlanProperties, + cache: Arc<PlanProperties>, } impl TopKExec { fn new(input: Arc<dyn ExecutionPlan>, k: usize) -> Self { let cache = Self::compute_properties(input.schema()); - Self { input, k, cache } + Self { + input, + k, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -704,11 +706,7 @@ impl ExecutionPlan for TopKExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc<PlanProperties> { &self.cache } @@ -733,9 +731,11 @@ impl ExecutionPlan for TopKExec { partition: usize, context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { - if 0 != partition { - return internal_err!("TopKExec invalid partition {partition}"); - } + assert_eq_or_internal_err!( + partition, + 0, + "TopKExec invalid partition {partition}" + ); Ok(Box::pin(TopKReader { input: self.input.execute(partition, context)?, @@ -745,10 +745,20 @@ impl ExecutionPlan for TopKExec { })) } - fn statistics(&self) -> Result<Statistics> { - // to improve the optimizability of this plan - // better statistics inference could be provided - Ok(Statistics::new_unknown(&self.schema())) + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result<TreeNodeRecursion>, + ) -> Result<TreeNodeRecursion> { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 3ca8f846aa5e5..b758aeb5209e8 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License.
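A note on the snapshot churn in this and the surrounding test files: `insta` inline snapshots are ordinary Rust raw strings, so `@r###"..."###` can shrink to `@r"..."` whenever the snapshot body contains no double quotes, while bodies that embed quotes (such as the `arrow_cast(t.time,Utf8("Int64"))` header earlier in this patch) keep one hash as `@r#"..."#`. The underlying raw-string rule, for reference:

    let plain = r"no quotes inside, so no hash delimiter is needed";
    let hashed = r#"embeds "quotes", so a # delimiter is required"#;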
-use std::any::Any; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ - builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, - Int32Array, RecordBatch, StringArray, + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, + builder::BooleanBuilder, cast::AsArray, }; +use arrow::array::{Int8Array, UInt64Array, as_string_array, create_array, record_batch}; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; @@ -38,15 +37,17 @@ use datafusion_common::metadata::FieldMetadata; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_datafusion_err, - exec_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, assert_batches_eq, + assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, + not_impl_err, plan_err, }; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, - LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, + OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, lit_with_metadata, }; +use datafusion_expr_common::signature::TypeSignature; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; use regex::Regex; @@ -63,13 +64,13 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; let actual = plan_and_collect(&ctx, sql).await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------+ | avg(custom_sqrt(aggregate_test_100.c11)) | +------------------------------------------+ | 0.6584408483418835 | +------------------------------------------+ - "###); + "); Ok(()) } @@ -82,13 +83,13 @@ async fn csv_query_avg_sqrt() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; let actual = plan_and_collect(&ctx, sql).await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------+ | avg(custom_sqrt(aggregate_test_100.c12)) | +------------------------------------------+ | 0.6706002946036459 | +------------------------------------------+ - "###); + "); Ok(()) } @@ -153,7 +154,7 @@ async fn scalar_udf() -> Result<()> { let result = DataFrame::new(ctx.state(), plan).collect().await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+-----+-----------------+ | a | b | my_add(t.a,t.b) | +-----+-----+-----------------+ @@ -162,7 +163,7 @@ async fn scalar_udf() -> Result<()> { | 10 | 12 | 22 | | 100 | 120 | 220 | +-----+-----+-----------------+ - "###); + 
"); let batch = &result[0]; let a = as_int32_array(batch.column(0))?; @@ -199,10 +200,6 @@ impl std::fmt::Debug for Simple0ArgsScalarUDF { } impl ScalarUDFImpl for Simple0ArgsScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -279,7 +276,7 @@ async fn scalar_udf_zero_params() -> Result<()> { ctx.register_udf(ScalarUDF::from(get_100_udf)); let result = plan_and_collect(&ctx, "select get_100() a from t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+ | a | +-----+ @@ -288,22 +285,22 @@ async fn scalar_udf_zero_params() -> Result<()> { | 100 | | 100 | +-----+ - "###); + "); let result = plan_and_collect(&ctx, "select get_100() a").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+ | a | +-----+ | 100 | +-----+ - "###); + "); let result = plan_and_collect(&ctx, "select get_100() from t where a=999").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" ++ ++ - "###); + "); Ok(()) } @@ -330,13 +327,13 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { // Make sure that the UDF is used instead of the built-in function let result = plan_and_collect(&ctx, "select abs(a) a from t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +---+ | a | +---+ | 1 | +---+ - "###); + "); Ok(()) } @@ -425,20 +422,21 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") .await .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_func\'")); + assert!( + err.to_string() + .contains("Error during planning: Invalid function \'my_func\'") + ); // Can call it if you put quotes let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +--------------+ | MY_FUNC(t.i) | +--------------+ | 1 | +--------------+ - "###); + "); Ok(()) } @@ -469,13 +467,13 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { ctx.register_udf(udf); let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +------------+ | dummy(t.i) | +------------+ | 1 | +------------+ - "###); + "); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; insta::assert_snapshot!(batches_to_string(&alias_result), @r" @@ -508,10 +506,6 @@ impl AddIndexToStringVolatileScalarUDF { } impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -675,9 +669,6 @@ impl CastToI64UDF { } impl ScalarUDFImpl for CastToI64UDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "cast_to_i64" } @@ -696,7 +687,7 @@ impl ScalarUDFImpl for CastToI64UDF { fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { // DataFusion should have ensured the function is called with just a // single argument @@ -712,10 +703,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the 
correct type - Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: DataType::Int64, - }) + Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), DataType::Int64)) }; // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) @@ -800,9 +788,6 @@ impl TakeUDF { /// Implement a ScalarUDFImpl whose return type is a function of the input values impl ScalarUDFImpl for TakeUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "take" } @@ -945,13 +930,10 @@ struct ScalarFunctionWrapper { expr: Expr, signature: Signature, return_type: DataType, + defaults: Vec<Option<Expr>>, } impl ScalarUDFImpl for ScalarFunctionWrapper { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -971,9 +953,9 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn simplify( &self, args: Vec<Expr>, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result<ExprSimplifyResult> { - let replacement = Self::replacement(&self.expr, &args)?; + let replacement = Self::replacement(&self.expr, &args, &self.defaults)?; Ok(ExprSimplifyResult::Simplified(replacement)) } @@ -981,7 +963,11 @@ impl ScalarFunctionWrapper { impl ScalarFunctionWrapper { // replaces placeholders with actual arguments - fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> { + fn replacement( + expr: &Expr, + args: &[Expr], + defaults: &[Option<Expr>], + ) -> Result<Expr> { let result = expr.clone().transform(|e| { let r = match e { Expr::Placeholder(placeholder) => { @@ ... Self::parse_placeholder_identifier(&placeholder.id)?; if placeholder_position < args.len() { Transformed::yes(args[placeholder_position].clone()) - } else { + } else if placeholder_position >= defaults.len() { exec_err!( - "Function argument {} not provided, argument missing!", + "Invalid placeholder, out of range: {}", placeholder.id )?
+ } else { + match defaults[placeholder_position] { + Some(ref default) => Transformed::yes(default.clone()), + None => exec_err!( + "Function argument {} not provided, argument missing!", + placeholder.id + )?, + } } } _ => Transformed::no(e), @@ -1021,6 +1015,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { type Error = DataFusionError; fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> { + let args = definition.args.unwrap_or_default(); + let defaults: Vec<Option<Expr>> = + args.iter().map(|a| a.default_expr.clone()).collect(); + let signature: Signature = match defaults.iter().position(|v| v.is_some()) { + Some(pos) => { + let mut type_signatures: Vec<TypeSignature> = vec![]; + // Generate all valid signatures + for n in pos..defaults.len() + 1 { + if n == 0 { + type_signatures.push(TypeSignature::Nullary) + } else { + type_signatures.push(TypeSignature::Exact( + args.iter().take(n).map(|a| a.data_type.clone()).collect(), + )) + } + } + Signature::one_of( + type_signatures, + definition.params.behavior.unwrap_or(Volatility::Volatile), + ) + } + None => Signature::exact( + args.iter().map(|a| a.data_type.clone()).collect(), + definition.params.behavior.unwrap_or(Volatility::Volatile), + ), + }; Ok(Self { name: definition.name, expr: definition @@ -1030,15 +1050,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { return_type: definition .return_type .expect("Return type has to be defined!"), - signature: Signature::exact( - definition - .args - .unwrap_or_default() - .into_iter() - .map(|a| a.data_type) - .collect(), - definition.params.behavior.unwrap_or(Volatility::Volatile), - ), + signature, + defaults, }) } } @@ -1061,10 +1074,11 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { // Create the `better_add` function dynamically via CREATE FUNCTION statement assert!(ctx.sql(sql).await.is_ok()); // try to `drop function` when sql options have allow ddl disabled - assert!(ctx - .sql_with_options("drop function better_add", options) - .await - .is_err()); + assert!( + ctx.sql_with_options("drop function better_add", options) + .await + .is_err() + ); let result = ctx .sql("select better_add(2.0, 2.0)") .await @@ -1109,6 +1123,175 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { "#; assert!(ctx.sql(bad_definition_sql).await.is_err()); + // FIXME: Definitions with invalid placeholders are allowed, fail at runtime + let bad_expression_sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE) + RETURNS DOUBLE + RETURN $1 + $3 + "#; + assert!(ctx.sql(bad_expression_sql).await.is_ok()); + + let err = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await + .expect_err("unknown placeholder"); + let expected = "Optimizer rule 'simplify_expressions' failed\ncaused by\nExecution error: Invalid placeholder, out of range: $3"; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await?
+ .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + // cannot mix named and positional style + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $1 + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("cannot mix named and positional style"); + let expected = "Error during planning: All function arguments must use either named or positional style."; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE = 2.0, b DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + // Check all function arity supported + let result = ctx.sql("select better_add()").await?.collect().await?; + + assert_batches_eq!( + &[ + "+--------------+", + "| better_add() |", + "+--------------+", + "| 4.0 |", + "+--------------+", + ], + &result + ); + + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err()); + assert!(ctx.sql("drop function better_add").await.is_ok()); + + // works with positional style + let sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + assert!(ctx.sql(sql).await.is_ok()); + + assert!(ctx.sql("select better_add()").await.is_err()); + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + // non-default argument cannot follow default argument + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(a DOUBLE = 2.0, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("non-default argument cannot follow default argument"); + let expected = + "Error during planning: Non-default arguments cannot follow default arguments."; + assert!(expected.starts_with(&err.strip_backtrace())); + + let expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + let result = ctx.sql(expression_sql).await; + + assert!(result.is_ok()); Ok(()) } @@ -1239,10 +1422,6 @@ impl MyRegexUdf { } impl ScalarUDFImpl for MyRegexUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "regex_udf" } @@ -1407,10 +1586,6 @@ impl MetadataBasedUdf { } impl ScalarUDFImpl for 
MetadataBasedUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -1616,10 +1791,6 @@ impl Default for ExtensionBasedUdf { } } impl ScalarUDFImpl for ExtensionBasedUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -1786,9 +1957,6 @@ async fn test_config_options_work_for_scalar_func() -> Result<()> { } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "TestScalarUDF" } @@ -1850,10 +2018,6 @@ async fn test_extension_metadata_preserve_in_sql_values() -> Result<()> { } impl ScalarUDFImpl for MakeExtension { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "make_extension" } @@ -1931,10 +2095,6 @@ async fn test_extension_metadata_preserve_in_subquery() -> Result<()> { } impl ScalarUDFImpl for ExtensionScalarPredicate { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "extension_predicate" } diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 2c6611f382cea..c8ded3a6fce3f 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -21,20 +21,20 @@ use std::path::Path; use std::sync::Arc; use arrow::array::Int64Array; -use arrow::csv::reader::Format; use arrow::csv::ReaderBuilder; +use arrow::csv::reader::Format; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::test_util::batches_to_string; -use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; use datafusion::execution::TaskContext; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::prelude::SessionContext; -use datafusion_catalog::Session; use datafusion_catalog::TableFunctionImpl; +use datafusion_catalog::{Session, TableFunctionArgs}; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; @@ -55,7 +55,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&rbs), @r###" + insta::assert_snapshot!(batches_to_string(&rbs), @r" +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ | n_nationkey | n_name | n_regionkey | n_comment | +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ @@ -65,7 +65,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { | 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d | | 5 | ETHIOPIA | 0 | ven packages wake quickly. 
regu | +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ - "###); + "); // just run, return all rows let rbs = ctx @@ -74,7 +74,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&rbs), @r###" + insta::assert_snapshot!(batches_to_string(&rbs), @r" +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ | n_nationkey | n_name | n_regionkey | n_comment | +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ @@ -89,7 +89,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { | 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull | | 10 | IRAN | 4 | efully alongside of the slyly final dependencies. | +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ - "###); + "); Ok(()) } @@ -118,10 +118,6 @@ struct SimpleCsvTable { #[async_trait] impl TableProvider for SimpleCsvTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -200,12 +196,13 @@ impl SimpleCsvTable { struct SimpleCsvTableFunc {} impl TableFunctionImpl for SimpleCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { + fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> { + let exprs = args.exprs(); let mut new_exprs = vec![]; let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { + Expr::Literal(ScalarValue::Utf8(Some(path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), @@ -221,6 +218,31 @@ impl TableFunctionImpl for SimpleCsvTableFunc { } } +/// Test that expressions passed to UDTFs are properly type-coerced +/// This is a regression test for https://github.com/apache/datafusion/issues/19914 +#[tokio::test] +async fn test_udtf_type_coercion() -> Result<()> { + use datafusion::datasource::MemTable; + + #[derive(Debug)] + struct NoOpTableFunc; + + impl TableFunctionImpl for NoOpTableFunc { + fn call_with_args(&self, _: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> { + let schema = Arc::new(arrow::datatypes::Schema::empty()); + Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)) + } + } + + let ctx = SessionContext::new(); + ctx.register_udtf("f", Arc::new(NoOpTableFunc)); + + // This should not panic - the array elements should be coerced to Float64 + let _ = ctx.sql("SELECT * FROM f(ARRAY[0.1, 1, 2])").await?; + + Ok(()) +} + fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> { let mut file = File::open(csv_path)?; let (schema, _) = Format::default() diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 33607ebc0d2cc..afaf269ca1200 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,8 +19,8 @@ //!
user defined window functions use arrow::array::{ - record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, - UInt64Array, + Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, UInt64Array, + record_batch, }; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::FieldRef; @@ -38,17 +38,16 @@ use datafusion_functions_window_common::{ expr::ExpressionArgs, field::WindowUDFFieldArgs, }; use datafusion_physical_expr::{ - expressions::{col, lit}, PhysicalExpr, + expressions::{col, lit}, }; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::{ - any::Any, ops::Range, sync::{ - atomic::{AtomicUsize, Ordering}, Arc, + atomic::{AtomicUsize, Ordering}, }, }; @@ -62,8 +61,7 @@ const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \ from t ORDER BY x, y"; /// A query with a window function evaluated over a moving window -const BOUNDED_WINDOW_QUERY: &str = - "SELECT x, y, val, \ +const BOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ odd_counter(val) OVER (PARTITION BY x ORDER BY y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ from t ORDER BY x, y"; @@ -75,22 +73,22 @@ async fn test_setup() { let sql = "SELECT * from t order by x, y"; let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+ - | x | y | val | - +---+---+-----+ - | 1 | a | 0 | - | 1 | b | 1 | - | 1 | c | 2 | - | 2 | d | 3 | - | 2 | e | 4 | - | 2 | f | 5 | - | 2 | g | 6 | - | 2 | h | 6 | - | 2 | i | 6 | - | 2 | j | 6 | - +---+---+-----+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+ + | x | y | val | + +---+---+-----+ + | 1 | a | 0 | + | 1 | b | 1 | + | 1 | c | 2 | + | 2 | d | 3 | + | 2 | e | 4 | + | 2 | f | 5 | + | 2 | g | 6 | + | 2 | h | 6 | + | 2 | i | 6 | + | 2 | j | 6 | + +---+---+-----+ + "); } /// Basic user defined window function @@ -101,22 +99,22 @@ async fn test_udwf() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + 
+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); // evaluated on two distinct batches assert_eq!(test_state.evaluate_all_called(), 2); @@ -175,22 +173,22 @@ async fn test_udwf_bounded_window_ignores_frame() { // Since the UDWF doesn't say it needs the window frame, the frame is ignored let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // evaluated on 2 distinct batches (when x=1 and x=2) assert_eq!(test_state.evaluate_called(), 0); @@ -205,22 +203,22 @@ async fn test_udwf_bounded_window() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | 0 | - | 2 | i | 6 | 0 | - | 2 | j | 6 | 0 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | 
e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | 0 | + | 2 | i | 6 | 0 | + | 2 | j | 6 | 0 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate is called for each input rows assert_eq!(test_state.evaluate_called(), 10); @@ -237,22 +235,22 @@ async fn test_stateful_udwf() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 0 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 1 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 0 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 1 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 10); assert_eq!(test_state.evaluate_all_called(), 0); @@ -268,22 +266,22 @@ async fn test_stateful_udwf_bounded_window() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | 0 | - | 2 | i | 6 | 0 | - | 2 | j | 6 | 0 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ 
+ | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | 0 | + | 2 | i | 6 | 0 | + | 2 | j | 6 | 0 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate and update_state is called for each input row assert_eq!(test_state.evaluate_called(), 10); @@ -298,22 +296,22 @@ async fn test_udwf_query_include_rank() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 3 | - | 1 | b | 1 | 2 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 7 | - | 2 | e | 4 | 6 | - | 2 | f | 5 | 5 | - | 2 | g | 6 | 4 | - | 2 | h | 6 | 3 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 1 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 3 | + | 1 | b | 1 | 2 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 7 | + | 2 | e | 4 | 6 | + | 2 | f | 5 | 5 | + | 2 | g | 6 | 4 | + | 2 | h | 6 | 3 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 1 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 0); assert_eq!(test_state.evaluate_all_called(), 0); @@ -329,22 +327,22 @@ async fn test_udwf_bounded_query_include_rank() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 3 | - | 1 | b | 1 | 2 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 7 | - | 2 | e | 4 | 6 | - | 2 | f | 5 | 5 | - | 2 | g | 6 | 4 | - | 2 | h | 6 | 3 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 1 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + 
+---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 3                                                                                                            | + | 1 | b | 1 | 2                                                                                                            | + | 1 | c | 2 | 1                                                                                                            | + | 2 | d | 3 | 7                                                                                                            | + | 2 | e | 4 | 6                                                                                                            | + | 2 | f | 5 | 5                                                                                                            | + | 2 | g | 6 | 4                                                                                                            | + | 2 | h | 6 | 3                                                                                                            | + | 2 | i | 6 | 2                                                                                                            | + | 2 | j | 6 | 1                                                                                                            | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 0); assert_eq!(test_state.evaluate_all_called(), 0); @@ -362,22 +360,22 @@ async fn test_udwf_bounded_window_returns_null() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1                                                                                                            | - | 1 | b | 1 | 1                                                                                                            | - | 1 | c | 2 | 1                                                                                                            | - | 2 | d | 3 | 1                                                                                                            | - | 2 | e | 4 | 2                                                                                                            | - | 2 | f | 5 | 1                                                                                                            | - | 2 | g | 6 | 1                                                                                                            | - | 2 | h | 6 |                                                                                                              | - | 2 | i | 6 |                                                                                                              | - | 2 | j | 6 |                                                                                                              | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1                                                                                                            | + | 1 | b | 1 | 1                                                                                                            | + | 1 | c | 2 | 1                                                                                                            | + | 2 | d | 3 | 1                                                                                                            | + | 2 | e | 4 | 2                                                                                                            | + | 2 | f | 5 | 1                                                                                                            | + | 2 | g | 6 | 1                                                                                                            | + | 2 | h | 6 |                                                                                                              | + | 2 | i | 6 |                                                                                                              | + | 2 | j | 6 |                                                                                                              | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate is called for each input rows assert_eq!(test_state.evaluate_called(), 10); @@ -537,7 +535,7 @@ impl OddCounter { impl SimpleWindowUDF { fn new(test_state: Arc<TestState>) -> Self { let signature = - Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Signature::exact(vec![DataType::Int64], Volatility::Immutable); Self { signature, test_state: test_state.into(), @@ -547,10 +545,6 @@ impl OddCounter { } impl WindowUDFImpl for SimpleWindowUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "odd_counter" } @@ -616,7 +610,9 @@ impl PartitionEvaluator for OddCounter { ranks_in_partition: &[Range<usize>], ) -> Result<ArrayRef> { self.test_state.inc_evaluate_all_with_rank_called(); - println!("evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}"); + println!( + "evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}" + ); // when evaluating with ranks, just return the inverse rank instead let array: Int64Array = ranks_in_partition .iter() @@ -674,10 +670,6 @@ impl VariadicWindowUDF { } impl WindowUDFImpl for VariadicWindowUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "variadic_window_udf" } @@ -818,10 +810,6 @@ impl MetadataBasedWindowUdf { }
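A note on the `as_any` deletions running through this file: as in the scalar UDF changes earlier, implementations apparently no longer write `as_any` themselves, presumably because the trait now supplies a default. Under that assumption, a minimal window UDF reduces to roughly the sketch below; the struct, names, and method set are illustrative only (they assume the current datafusion-expr API and are not part of this diff):

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::Result;
use datafusion_expr::function::PartitionEvaluatorArgs;
use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;

/// Illustrative-only window UDF: note there is no `as_any` implementation.
#[derive(Debug)]
struct MinimalUdwf {
    signature: Signature,
}

impl MinimalUdwf {
    fn new() -> Self {
        Self {
            // same shape as SimpleWindowUDF above: a single Int64 argument
            signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
        }
    }
}

impl WindowUDFImpl for MinimalUdwf {
    fn name(&self) -> &str {
        "minimal_udwf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn partition_evaluator(
        &self,
        _args: PartitionEvaluatorArgs,
    ) -> Result<Box<dyn PartitionEvaluator>> {
        // a real implementation returns its PartitionEvaluator here,
        // as OddCounter does in the tests above
        todo!()
    }

    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
        // output column: a nullable Int64 named after the call site
        Ok(Arc::new(Field::new(field_args.name(), DataType::Int64, true)))
    }
}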
impl WindowUDFImpl for MetadataBasedWindowUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } diff --git a/datafusion/datasource-arrow/Cargo.toml b/datafusion/datasource-arrow/Cargo.toml index b3d1e3f2accc9..2718e424c6386 100644 --- a/datafusion/datasource-arrow/Cargo.toml +++ b/datafusion/datasource-arrow/Cargo.toml @@ -51,6 +51,9 @@ tokio = { workspace = true } [dev-dependencies] chrono = { workspace = true } +# Note: add additional linter rules in lib.rs (a sketch is shown below). +# Cargo does not yet support combining workspace lints with extra per-crate rules in subcrates +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -59,6 +62,6 @@ name = "datafusion_datasource_arrow" path = "src/mod.rs" [features] -compression = [ - "arrow-ipc/zstd", -] +# This feature is deprecated: core functionality in the SpillManager now requires everything +# this feature used to enable, so it will be removed in a future version. +compression = [] diff --git a/datafusion/datasource-arrow/NOTICE.txt b/datafusion/datasource-arrow/NOTICE.txt index 7f3c80d606c07..0bd2d52368fea 100644 --- a/datafusion/datasource-arrow/NOTICE.txt +++ b/datafusion/datasource-arrow/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2025 The Apache Software Foundation +Copyright 2019-2026 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 3b85640804219..9297486ad66e7 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -19,31 +19,31 @@ //! //! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) -use std::any::Any; -use std::borrow::Cow; use std::collections::HashMap; use std::fmt::{self, Debug}; +use std::io::{Seek, SeekFrom}; use std::sync::Arc; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::ipc::convert::fb_to_schema; -use arrow::ipc::reader::FileReader; +use arrow::ipc::reader::{FileReader, StreamReader}; use arrow::ipc::writer::IpcWriteOptions; -use arrow::ipc::{root_as_message, CompressionType}; +use arrow::ipc::{CompressionType, root_as_message}; use datafusion_common::error::Result; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - internal_datafusion_err, not_impl_err, DataFusionError, GetExt, Statistics, - DEFAULT_ARROW_EXTENSION, + DEFAULT_ARROW_EXTENSION, DataFusionError, GetExt, Statistics, + internal_datafusion_err, not_impl_err, }; use datafusion_common_runtime::{JoinSet, SpawnedTask}; +use datafusion_datasource::TableSchema; use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_datasource::write::{ - get_writer_schema, ObjectWriterBuilder, SharedBuffer, + ObjectWriterBuilder, SharedBuffer, get_writer_schema, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; @@ -59,9 +59,12 @@ use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::write::demux::DemuxedStreamReceiver; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; -use futures::stream::BoxStream; use futures::StreamExt;
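The sketch promised by the Cargo.toml lint note above: with `[lints] workspace = true`, extra crate-level rules cannot be declared alongside the workspace lints in Cargo.toml (rust-lang/cargo#13157), so they go at the top of the crate root, here src/mod.rs. The specific rules below are illustrative only, not part of this diff:

// top of src/mod.rs -- workspace lints come from Cargo.toml's
// `[lints] workspace = true`; anything extra is a crate attribute:
#![deny(clippy::clone_on_ref_ptr)] // illustrative extra rule
#![warn(missing_docs)]             // illustrative extra rule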
-use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use futures::stream::BoxStream; +use object_store::{ + GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, ObjectStoreExt, + path::Path, +}; use tokio::io::AsyncWriteExt; /// Initial writing buffer size. Note this is just a size hint for efficiency. It @@ -71,8 +74,8 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// If the buffered Arrow data exceeds this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +/// Factory struct used to create [`ArrowFormat`] #[derive(Default, Debug)] -/// Factory struct used to create [ArrowFormat] pub struct ArrowFormatFactory; impl ArrowFormatFactory { @@ -94,10 +97,6 @@ impl FileFormatFactory for ArrowFormatFactory { fn default(&self) -> Arc<dyn FileFormat> { Arc::new(ArrowFormat) } - - fn as_any(&self) -> &dyn Any { - self - } } impl GetExt for ArrowFormatFactory { @@ -107,16 +106,12 @@ impl GetExt for ArrowFormatFactory { } } -/// Arrow `FileFormat` implementation. +/// Arrow [`FileFormat`] implementation. #[derive(Default, Debug)] pub struct ArrowFormat; #[async_trait] impl FileFormat for ArrowFormat { - fn as_any(&self) -> &dyn Any { - self - } - fn get_ext(&self) -> String { ArrowFormatFactory::new().get_ext() } @@ -150,14 +145,27 @@ impl FileFormat for ArrowFormat { let schema = match r.payload { #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(mut file, _) => { - let reader = FileReader::try_new(&mut file, None)?; - reader.schema() - } - GetResultPayload::Stream(stream) => { - infer_schema_from_file_stream(stream).await? + match FileReader::try_new(&mut file, None) { + Ok(reader) => reader.schema(), + Err(file_error) => { + // not in the file format, but FileReader read some bytes + // while trying to parse the file and so we need to rewind + // it to the beginning of the file + file.seek(SeekFrom::Start(0))?; + match StreamReader::try_new(&mut file, None) { + Ok(reader) => reader.schema(), + Err(stream_error) => { + return Err(internal_datafusion_err!( + "Failed to parse Arrow file as either file format or stream format. File format error: {file_error}. Stream format error: {stream_error}" + )); + } + } + } + } } + GetResultPayload::Stream(stream) => infer_stream_schema(stream).await?, }; - schemas.push(schema.as_ref().clone()); + schemas.push(Arc::unwrap_or_clone(schema)); } let merged_schema = Schema::try_merge(schemas)?; Ok(Arc::new(merged_schema)) @@ -175,10 +183,40 @@ impl FileFormat for ArrowFormat { async fn create_physical_plan( &self, - _state: &dyn Session, + state: &dyn Session, conf: FileScanConfig, ) -> Result<Arc<dyn ExecutionPlan>> { - let source = Arc::new(ArrowSource::default()); + let object_store = state.runtime_env().object_store(&conf.object_store_url)?; + let object_location = &conf + .file_groups + .first() + .ok_or_else(|| internal_datafusion_err!("No files found in file group"))? + .files() + .first() + .ok_or_else(|| internal_datafusion_err!("No files found in file group"))? + .object_meta + .location;
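// What the format probe just below relies on (an illustrative sketch under
// assumed object_store APIs, not necessarily the helper's actual body): an
// IPC *file* begins with the 6-byte "ARROW1" magic, while an IPC *stream*
// begins directly with an encapsulated message, so fetching the first few
// bytes of the object is enough to tell the two layouts apart:
//
//     let head = object_store
//         .get_opts(
//             object_location,
//             GetOptions {
//                 range: Some(GetRange::Bounded(0..6)),
//                 ..Default::default()
//             },
//         )
//         .await?
//         .bytes()
//         .await?;
//     let is_file_format = head.as_ref() == b"ARROW1";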
+ + let table_schema = TableSchema::new( + Arc::clone(conf.file_schema()), + conf.table_partition_cols().clone(), + ); + + let mut source: Arc<dyn FileSource> = + match is_object_in_arrow_ipc_file_format(object_store, object_location).await + { + Ok(true) => Arc::new(ArrowSource::new_file_source(table_schema)), + Ok(false) => Arc::new(ArrowSource::new_stream_file_source(table_schema)), + Err(e) => Err(e)?, + }; + + // Preserve projection from the original file source + if let Some(projection) = conf.file_source.projection() + && let Some(new_source) = source.try_pushdown_projection(projection)? + { + source = new_source; + } + let config = FileScanConfigBuilder::from(conf) .with_source(source) .build(); @@ -202,12 +240,12 @@ impl FileFormat for ArrowFormat { Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } - fn file_source(&self) -> Arc<dyn FileSource> { - Arc::new(ArrowSource::default()) + fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> { + Arc::new(ArrowSource::new_file_source(table_schema)) } } -/// Implements [`FileSink`] for writing to arrow_ipc files +/// Implements [`FileSink`] for Arrow IPC files struct ArrowFileSink { config: FileSinkConfig, } @@ -327,10 +365,6 @@ impl DisplayAs for ArrowFileSink { #[async_trait] impl DataSink for ArrowFileSink { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> &SchemaRef { self.config.output_schema() } @@ -344,107 +378,178 @@ impl DataSink for ArrowFileSink { } } +// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. +// See + const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; -/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. -/// See -async fn infer_schema_from_file_stream( +async fn infer_stream_schema( mut stream: BoxStream<'static, object_store::Result<Bytes>>, ) -> Result<SchemaRef> { - // Expected format: - // <magic number "ARROW1"> - 6 bytes - // <empty padding bytes [to 8 byte boundary]> - 2 bytes - // <continuation: 0xFFFFFFFF> - 4 bytes, not present below v0.15.0 - // <metadata_size: int32> - 4 bytes - // <metadata_flatbuffer: bytes> - // <rest of file bytes> - // - // So in first read we need at least all known sized sections, - // which is 6 + 2 + 4 + 4 = 16 bytes. - let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; - - // Files should start with these magic bytes - if bytes[0..6] != ARROW_MAGIC { - return Err(ArrowError::ParseError( - "Arrow file does not contain correct header".to_string(), - ))?; - } - - // Since continuation marker bytes added in later versions - let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { - (&bytes[12..16], 16) + // IPC streaming format. + // See https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + // + // <SCHEMA> + // <DICTIONARY 0> + // ... + // <DICTIONARY k - 1> + // <RECORD BATCH 0> + // ... + // <DICTIONARY x DELTA> + // ... + // <DICTIONARY y DELTA> + // ... + // <RECORD BATCH n - 1> + // <EOS [optional]: 0x00000000> + + // The streaming format is made up of a sequence of encapsulated messages. + // See https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format + // + // <continuation: 0xFFFFFFFF> (added in v0.15.0) + // <metadata_size: int32> + // <metadata_flatbuffer: bytes> + // <padding> + // <message body> + // + // The first message is the schema. + + // IPC file format is a wrapper around the streaming format with indexing information. + // See https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format + // + // <magic number "ARROW1"> + // <empty padding bytes [to 8 byte boundary]> + // <STREAMING FORMAT with EOS> + // <FOOTER>